In [1]:
import math
import os
import random
import sys
import time

import gzip

import numpy as np

import msprime
import tskit
import tsinfer
from tsinfer import make_ancestors_ts

import cyvcf2
import demes
import demesdraw

# Record the version of every simulation / inference / IO package in use.
print("\n".join(
    f"{package_name} {package.__version__}"
    for package_name, package in (
        ("tskit", tskit),
        ("tsinfer", tsinfer),
        ("msprime", msprime),
        ("demes", demes),
        ("cyvcf2", cyvcf2),
    )
))

tskit 0.5.2.dev0
tsinfer 0.2.3
msprime 1.1.1
demes 0.2.1
cyvcf2 0.30.14


### Helper functions

In [18]:
def count_sites_by_type(ts_or_sd):
    """
    Print a tally of sites grouped by how many alleles each site lists.

    Takes a tree sequence or sample data,
    because those two types of objects have a .variants() function.

    N.B.
      The allele count is taken from each variant's allele list, with any
      ``None`` entry (used to flag missing data) excluded. It is NOT the
      number of alleles actually observed in the genotypes, so a site can be
      reported as biallelic even if the samples carry only one allele.

    A biallelic site is counted as a singleton when exactly one genotype is
    the allele at index 1. Missing genotypes (tskit.MISSING_DATA, i.e. -1)
    are ignored rather than summed, so they cannot offset derived copies.

    :param ts_or_sd: tree sequence or SampleData object
    :return: None (the counts are printed)
    """
    sites_mono = 0
    sites_bi = 0
    sites_bi_singleton = 0
    sites_tri = 0
    sites_quad = 0  # four or more alleles

    for v in ts_or_sd.variants():
        num_alleles = len(set(v.alleles) - {None})
        if num_alleles == 1:
            sites_mono += 1
        elif num_alleles == 2:
            sites_bi += 1
            # Count copies of allele 1 directly: summing the genotypes would let
            # missing values (-1) cancel derived copies, e.g. [1, 1, -1] sums to 1.
            if np.count_nonzero(np.asarray(v.genotypes) == 1) == 1:
                sites_bi_singleton += 1
        elif num_alleles == 3:
            sites_tri += 1
        else:
            sites_quad += 1

    print(f"\tsites mono : {sites_mono}")
    print(f"\tsites bi   : {sites_bi} ({sites_bi_singleton} singletons)")
    print(f"\tsites tri  : {sites_tri}")
    print(f"\tsites quad : {sites_quad}")
    print(f"\tsites total: {sites_mono + sites_bi + sites_tri + sites_quad}")

    return None

### Simulation parameters

In [3]:
# Replicate number; it is embedded in the output VCF file names.
replicate_index = 0

# Sampling time of the query genomes, passed as the msprime SampleSet time.
sampling_time_query = 100

# Fraction of query sites to mask before imputation.
proportion_missing_sites = 0.8

### Simulate genealogy and genetic variation

In [4]:
# For simulations
# size_ref   = 1e4
# size_query = 1e3

# eff_pop_size = 10_000
# mutation_rate = 1e-8
# recombination_rate = 1e-8

# contig_id = '1'
# ploidy_level = 1
# sequence_length = 1_000_000

In [5]:
# For testing: a deliberately small configuration that runs quickly.
size_ref, size_query = 5, 5  # number of reference / query individuals

eff_pop_size = 10_000
mutation_rate = 1e-6
recombination_rate = 1e-7

contig_id = "1"
ploidy_level = 1
sequence_length = 10_000

In [6]:
# Uniform rate maps over the whole sequence: recombination first, then mutation.
recomb_rate_map, mut_rate_map = (
    msprime.RateMap.uniform(sequence_length=sequence_length, rate=uniform_rate)
    for uniform_rate in (recombination_rate, mutation_rate)
)

In [7]:
# One SampleSet per group, reference first: reference genomes are sampled at
# time 0, query genomes at sampling_time_query.
sample_set = [
    msprime.SampleSet(num_samples=group_size, time=group_time, ploidy=ploidy_level)
    for group_size, group_time in (
        (size_ref, 0),                       # Reference genomes
        (size_query, sampling_time_query),   # Query genomes
    )
]

In [19]:
# A simulated tree sequence does not contain any monoallelic sites,
# but there may be multiallelic sites.
#
# The seed is offset by replicate_index so that each replicate is a distinct,
# reproducible simulation. Replicate 0 keeps the original seed (123). With a
# fixed seed, every replicate would yield identical data and differ only in
# its output file names.
simulation_seed = 123 + replicate_index

ts_ancestry = msprime.sim_ancestry(
    samples=sample_set,
    population_size=eff_pop_size,
    model="hudson",
    recombination_rate=recomb_rate_map,
    discrete_genome=True,
    random_seed=simulation_seed,
)
ts_full = msprime.sim_mutations(
    ts_ancestry,
    rate=mut_rate_map,
    discrete_genome=True,
    random_seed=simulation_seed,
)

# Remove populations: clear the population table and detach every node from it.
tables = ts_full.dump_tables()
tables.populations.clear()
tables.nodes.population = np.full_like(tables.nodes.population, tskit.NULL)
ts_full = tables.tree_sequence()

print("TS full")
count_sites_by_type(ts_full)

TS full
	sites mono : 0
	sites bi   : 996 (376 singletons)
	sites tri  : 47
	sites quad : 0
	sites total: 1043


In [9]:
# Individuals and their sample nodes are numbered reference-first.
# Assumes msprime allocates them in sample_set order (reference, then query).
num_individuals = size_ref + size_query
individuals_ref = np.arange(0, size_ref, dtype=int)
individuals_query = np.arange(size_ref, num_individuals, dtype=int)
samples_ref = np.arange(0, ploidy_level * size_ref, dtype=int)
samples_query = np.arange(ploidy_level * size_ref, ploidy_level * num_individuals, dtype=int)

### Create an ancestor ts from the reference genomes

In [20]:
# Keep only the reference genomes, removing the branches that lead solely to
# the query genomes. filter_sites=False retains every site, including those
# with no mutations left among the reference genomes.
ts_ref = ts_full.simplify(samples_ref, filter_sites=False)

label = "TS ref"
print(f"{label} has {ts_ref.num_samples} sample genomes ({ts_ref.sequence_length} bp)")
print(f"{label} has {ts_ref.num_sites} sites and {ts_ref.num_trees} trees")
print(label)
count_sites_by_type(ts_ref)

TS ref has 5 sample genomes (10000.0 bp)
TS ref has 1043 sites and 80 trees
TS ref
	sites mono : 180
	sites bi   : 826 (356 singletons)
	sites tri  : 37
	sites quad : 0
	sites total: 1043


In [21]:
# Generating an ancestor ts automatically drops multiallelic sites. Sites that
# are biallelic in the full sample set but monoallelic among the reference
# genomes are dropped as well, leaving only biallelic sites in the ancestor ts.
ts_anc = make_ancestors_ts(ts=ts_ref, remove_leaves=True, samples=None)

label = "TS anc"
print(f"{label} has {ts_anc.num_samples} sample genomes ({ts_anc.sequence_length} bp)")
print(f"{label} has {ts_anc.num_sites} sites and {ts_anc.num_trees} trees")
print(label)
count_sites_by_type(ts_anc)

TS anc has 59 sample genomes (10000.0 bp)
TS anc has 455 sites and 65 trees
TS anc
	sites mono : 0
	sites bi   : 455 (0 singletons)
	sites tri  : 0
	sites quad : 0
	sites total: 455


### Create a SampleData object holding the query genomes 

In [22]:
# Convert the full simulation to SampleData, then keep only the query individuals.
sd_full = tsinfer.SampleData.from_tree_sequence(ts_full)
sd_query = sd_full.subset(individuals_query)

label = "SD query"
print(f"{label} has {sd_query.num_samples} sample genomes ({sd_query.sequence_length} bp)")
print(f"{label} has {sd_query.num_sites} sites")
print(label)
count_sites_by_type(sd_query)

SD query has 5 sample genomes (10000.0 bp)
SD query has 1043 sites
SD query
	sites mono : 0
	sites bi   : 996 (202 singletons)
	sites tri  : 47
	sites quad : 0
	sites total: 1043


In [40]:
def make_compatible_sample_data(sample_data, ancestors_ts):
    """
    Return an edited, finalised copy of ``sample_data`` (SD query) whose
    alleles agree with the biallelic sites of ``ancestors_ts`` (TS anc).

    Sites are matched by position. For each TS anc site, the SD query site at
    the same position falls into one of four cases:

      Case 1: the TS anc derived allele is absent from the SD allele list.
              Every non-ancestral genotype becomes tskit.MISSING_DATA and the
              allele list is reduced to [ancestral].
      Case 2: the SD alleles are exactly [ancestral, derived]. Left unchanged.
      Case 3: the TS anc derived allele is at index 1, but SD lists extra
              alleles. Genotypes > 1 become missing.
      Case 4: the TS anc derived allele is at another index. It is recoded to
              1, and every other non-ancestral genotype becomes missing.
      In cases 3 and 4 the allele list becomes [ancestral, derived].

    Sites present in SD query but not in TS anc are left as is.

    N.B.
      Two SampleData attributes sites_alleles and sites_genotypes,
      which are not explained in the tsinfer API doc,
      are used to facilitate the editing.

    :param sample_data: SampleData object (e.g. SD query)
    :param ancestors_ts: tree sequence (e.g. TS anc) whose sites are all
        biallelic and all present, by position, in ``sample_data``
    :return: finalised SampleData object
    :raises ValueError: if a site in ``ancestors_ts`` has no site at the same
        position in ``sample_data``
    """
    new_sample_data = sample_data.copy()
    sd_variants = sample_data.variants()

    # Walk the TS anc sites and let the SD iterator catch up to each one.
    # Both site lists are taken to be in increasing position order.
    sd_var = next(sd_variants, None)
    for ref_site in ancestors_ts.sites():
        while sd_var is not None and sd_var.site.position < ref_site.position:
            # Sites found in SD query but not in TS anc are not imputed.
            # Also, leave them as is in the SD query.
            sd_var = next(sd_variants, None)
        if sd_var is None or sd_var.site.position != ref_site.position:
            # Report the mismatch explicitly instead of letting a bare
            # StopIteration escape from next().
            raise ValueError(
                f"Site at position {ref_site.position} in ancestors_ts "
                "has no matching site in sample_data"
            )

        site_id = sd_var.site.id  # Site id in SD query
        genotypes = sd_var.genotypes
        assert len(ref_site.alleles) == 2  # Must be biallelic in TS anc

        # ref_site.alleles is a set, so the derived allele is the single
        # element left after removing the ancestral state.
        (derived_allele,) = ref_site.alleles - {ref_site.ancestral_state}

        # Ancestral allele should be the same in TS anc and SD query,
        # and it is at index 0 in the allele list in both.
        assert ref_site.ancestral_state == sd_var.alleles[0]

        if derived_allele not in sd_var.alleles:
            # Case 1: we cannot determine whether the SD query's derived
            # alleles descend from 0 or 1 in TS anc, so mark them as missing.
            new_sample_data.sites_genotypes[site_id] = np.where(
                genotypes != 0,       # Any non-ancestral genotype...
                tskit.MISSING_DATA,   # ...is flagged as missing;
                0,                    # ancestral genotypes are kept.
            )
            print("Site", site_id, "has no matching derived alleles in the query set")
            new_sample_data.sites_alleles[site_id] = [ref_site.ancestral_state]
            continue

        derived_allele_index = sd_var.alleles.index(derived_allele)
        if derived_allele_index == 1:
            if len(sd_var.alleles) == 2:
                # Case 2: ancestral and derived alleles already correspond exactly.
                continue
            # Case 3: keep genotypes 0 and 1 (and existing missing values);
            # alleles >= 2 are not in TS anc, so mark them as missing.
            new_sample_data.sites_genotypes[site_id] = np.where(
                genotypes > 1,
                tskit.MISSING_DATA,
                genotypes,
            )
            print("Site", site_id, "has extra derived alleles in the query set, which have been set as missing")
        else:
            # Case 4: recode the TS anc derived allele to index 1, matching
            # TS anc; any other non-ancestral allele is not in TS anc, so it
            # is marked as missing.
            new_sample_data.sites_genotypes[site_id] = np.where(
                genotypes == 0,
                0,
                np.where(
                    genotypes == derived_allele_index,
                    1,
                    tskit.MISSING_DATA,
                ),
            )
            print("Site", site_id, "has the target derived allele at a different index")

        # Make sure the allele list in SD query reflects the edits above.
        new_sample_data.sites_alleles[site_id] = [ref_site.ancestral_state, derived_allele]

    new_sample_data.finalise()

    return new_sample_data

In [41]:
# Align SD query with the biallelic sites of TS anc before matching.
new_sd_query = make_compatible_sample_data(sd_query, ts_anc)

Site 57 has extra derived alleles in the query set, which have been set as missing
Site 569 has extra derived alleles in the query set, which have been set as missing
Site 663 has extra derived alleles in the query set, which have been set as missing
Site 834 has the target derived allele at a different index


In [42]:
# Check that matching works: match the compatible (unmasked) query genomes
# against TS anc, and confirm that every query genome is a sample in the result.
ts_matched = tsinfer.match_samples(
    sample_data=new_sd_query,
    ancestors_ts=ts_anc,
)
assert ts_matched.num_samples == new_sd_query.num_samples, \
    "Matched tree sequence should contain every query genome as a sample"

### Create a SampleData object with masked sites

In [45]:
def sample_sites(sample_data, proportion_missing_sites, random_seed=None):
    """
    Draw, without replacement, the ids of the sites to mask in a SampleData object.

    The number of sites drawn is floor(num_sites * proportion_missing_sites).

    :param sample_data: SampleData object (only ``num_sites`` is used)
    :param proportion_missing_sites: float between 0 and 1
    :param random_seed: optional seed. If None (default), numpy's global
        legacy random state is used, exactly as before, so the draw follows any
        earlier ``np.random.seed`` call. Otherwise a fresh
        ``np.random.default_rng(random_seed)`` makes the draw reproducible on its own.
    :return: sorted numpy array of site ids
    :raises ValueError: if proportion_missing_sites is outside [0, 1]
    """
    if not 0 <= proportion_missing_sites <= 1:
        raise ValueError("Proportion of missing sites must be between 0 and 1.")

    num_sites = sample_data.num_sites
    num_masked_sites = int(np.floor(num_sites * proportion_missing_sites))

    # Both the legacy module-level API and a Generator expose choice(a, size, replace).
    rng = np.random if random_seed is None else np.random.default_rng(random_seed)
    masked_site_ids = rng.choice(num_sites, size=num_masked_sites, replace=False)

    return np.sort(masked_site_ids)

In [46]:
def mask_sites_in_sample_data(sample_data, masked_sites=None):
    """
    Create and return a finalised copy of a SampleData object in which every
    genotype at the listed sites is set to tskit.MISSING_DATA.

    :param sample_data: SampleData object
    :param masked_sites: iterable of site ids (NOT positions), e.g. the array
        returned by sample_sites(). None (the default) or an empty iterable
        masks nothing.
    :return: SampleData object
    """
    # `site_id in None` would raise TypeError, so None means "mask nothing".
    # A set also makes each membership test O(1) instead of scanning an array.
    masked = set() if masked_sites is None else {int(site_id) for site_id in masked_sites}

    new_sample_data = sample_data.copy()

    for v in sample_data.variants():
        if v.site.id in masked:
            new_sample_data.sites_genotypes[v.site.id] = np.full_like(
                v.genotypes, tskit.MISSING_DATA
            )

    new_sample_data.finalise()

    return new_sample_data

In [47]:
# Seed numpy's global random state, which sample_sites() draws from, so the
# masked sites are reproducible for a given replicate.
np.random.seed(123 + replicate_index)

# Draw site ids from the same SampleData that is masked. new_sd_query has the
# same sites as sd_query, because make_compatible_sample_data only edits
# genotypes and allele lists.
masked_site_ids = sample_sites(new_sd_query, proportion_missing_sites)

sd_query_masked = mask_sites_in_sample_data(
    new_sd_query,
    masked_sites=masked_site_ids,
)

### Impute the query genomes

In [48]:
# Match the masked query genomes against TS anc (the imputation step).
ts_imputed = tsinfer.match_samples(sd_query_masked, ancestors_ts=ts_anc)

### Write VCF

In [21]:
# Output file names follow "<dataset>.<replicate_index>.vcf.gz".
ref_vcf_file = f"ref.{replicate_index}.vcf.gz"

query_true_vcf_file = f"query_true.{replicate_index}.vcf.gz"
query_miss_vcf_file = f"query_miss.{replicate_index}.vcf.gz"

imputed_vcf_file = f"imputed.{replicate_index}.vcf.gz"

In [22]:
# Write the reference genomes as gzip-compressed VCF (text mode for write_vcf).
with gzip.open(ref_vcf_file, mode="wt") as ref_vcf_handle:
    ts_ref.write_vcf(ref_vcf_handle)

In [36]:
# TODO:
#   Write VCF from SD query
#   Write VCF from SD query masked

In [29]:
# Write the imputed query genomes as gzip-compressed VCF (text mode for write_vcf).
with gzip.open(imputed_vcf_file, mode="wt") as imputed_vcf_handle:
    ts_imputed.write_vcf(imputed_vcf_handle)