In [56]:
import math
import os
import random
import sys
import time

import gzip

import numpy as np

import tskit
from tsinfer import make_ancestors_ts
import tsinfer
import msprime

import demes
import demesdraw
import cyvcf2

print(f"tskit {tskit.__version__}")
print(f"tsinfer {tsinfer.__version__}")
print(f"msprime {msprime.__version__}")
print(f"demes {demes.__version__}")
print(f"cyvcf2 {cyvcf2.__version__}")

tskit 0.5.0
tsinfer 0.2.3
msprime 1.1.1
demes 0.2.1
cyvcf2 0.30.14


In [37]:
# Sourced and modified from:
# https://tsinfer.readthedocs.io/en/latest/tutorial.html#data-example
def get_chromosome_length(vcf):
    """
    Return the length of the only sequence listed in `vcf.seqlens`.

    Raises an AssertionError when `vcf.seqlens` does not hold exactly
    one entry.
    """
    contig_lengths = vcf.seqlens
    assert len(contig_lengths) == 1
    return contig_lengths[0]


def add_populations(vcf,
                    samples):
    """
    Register one population per distinct sample-name prefix.

    The population code of a sample is the first character of its name
    in `vcf.samples`. Each distinct code is added to `samples` once via
    `samples.add_population`, in the sorted order returned by
    `np.unique`, with metadata `{"name": code}`.

    Returns a list with one entry per sample (same order as
    `vcf.samples`) holding the value `add_population` returned for that
    sample's code.
    """
    codes_per_sample = [name[0] for name in vcf.samples]
    id_by_code = {
        code: samples.add_population(metadata = {"name" : code})
        for code in np.unique(codes_per_sample)
    }
    return [id_by_code[code] for code in codes_per_sample]


def add_individuals(vcf,
                    samples,
                    ploidy_level,
                    populations):
    """
    Add one individual to `samples` for every name in `vcf.samples`.

    Names and `populations` are paired positionally (pairing stops at
    the shorter of the two). Each individual is added with the given
    `ploidy_level`, metadata `{"name": <sample name>}` and its paired
    population.
    """
    for pairing in zip(vcf.samples, populations):
        sample_name, population_id = pairing
        samples.add_individual(
            ploidy = ploidy_level,
            metadata = {"name": sample_name},
            population = population_id,
        )


def add_sites(vcf,
              samples,
              ploidy_level,
              warn_monomorphic_sites = False):
    """
    Read the sites in the VCF and add them to the samples object,
    reordering the alleles to put the ancestral allele first,
    if it is available.

    Parameters
    ----------
    vcf : iterable of variant records exposing `POS`, `REF`, `ALT`,
        `INFO`, `num_unknown` and `genotypes` (as cyvcf2 variants do).
    samples : object with an `add_site(position, genotypes=, alleles=)`
        method.
    ploidy_level : 1 or 2; number of leading entries taken from each
        genotype row.
    warn_monomorphic_sites : if True, print a message for every site
        whose allele list (ignoring '.') has a single distinct allele.

    Raises
    ------
    AssertionError : if `ploidy_level` is neither 1 nor 2.
    ValueError : if two consecutive records share the same position.

    Notes
    -----
    * Whether the genotypes are phased is not checked.
    * Non-ancestral alleles keep their order of appearance in the VCF
      record, so the allele list is identical from run to run.
    """
    assert ploidy_level == 1 or ploidy_level == 2,\
        f"ploidy_level {ploidy_level} is not recognized."

    # `None` rather than 0: a first record located at position 0 must not be
    # mistaken for a duplicate of the sentinel.
    previous_pos = None
    for variant in vcf:
        pos = variant.POS
        # Only consecutive records are compared.
        if pos == previous_pos:
            raise ValueError(f"Duplicate positions for variant at position {pos}")
        previous_pos = pos

        alleles = [variant.REF] + variant.ALT # Exactly as in the input VCF file.
        if warn_monomorphic_sites and len(set(alleles) - {'.'}) == 1:
            print(f"Monomorphic site at {pos}")

        # Dangerous action: when the record has no "AA" INFO field, REF is
        # silently treated as the ancestral allele.
        ancestral = variant.INFO.get("AA", variant.REF)

        # Ancestral state must be first in the allele list. The remaining
        # alleles are de-duplicated in VCF order (dict keys preserve insertion
        # order), which avoids the run-dependent ordering of a set.
        ordered_alleles = [ancestral] + [
            allele for allele in dict.fromkeys(alleles) if allele != ancestral
        ]
        # Create an index mapping from the input VCF to tsinfer input.
        allele_index = {
            old_index: ordered_alleles.index(allele)
            for old_index, allele in enumerate(alleles)
        }
        # When genotype is missing...
        if variant.num_unknown > 0:
            allele_index[-1] = tskit.MISSING_DATA
            ordered_alleles += [None]
        # Map original allele indexes to their indexes in the new alleles list.
        genotypes = [
            allele_index[old_index]
            for row in variant.genotypes # cyvcf2 uses -1 to indicate missing data.
            for old_index in row[0:ploidy_level] # Trailing entry of each row is the phased flag.
        ]
        samples.add_site(pos,
                         genotypes = genotypes,
                         alleles = ordered_alleles)


def create_sample_data_from_vcf_file(vcf_file, ploidy = None):
    """
    Build a `tsinfer.SampleData` object from a VCF file.

    Populations, individuals and sites are added, in that order, with
    `add_populations`, `add_individuals` and `add_sites`. The sequence
    length is taken from the VCF header via `get_chromosome_length`.

    Parameters
    ----------
    vcf_file : path handed to `cyvcf2.VCF`.
    ploidy : ploidy used for individuals and genotype extraction. When
        omitted, the notebook-level `ploidy_level` variable is used, so
        existing single-argument calls behave as before.

    Returns
    -------
    The `tsinfer.SampleData` object created by the `with` block.
    """
    if ploidy is None:
        # Backward-compatible fallback to the notebook configuration cell.
        ploidy = ploidy_level

    vcf = cyvcf2.VCF(vcf_file,
                     gts012 = False, # 0=HOM_REF, 1=HET, 2=UNKNOWN, 3=HOM_ALT
                     strict_gt = True)
    try:
        with tsinfer.SampleData(
            sequence_length = get_chromosome_length(vcf)
        ) as samples:
            populations = add_populations(vcf, samples)
            add_individuals(vcf, samples, ploidy, populations)
            add_sites(vcf, samples, ploidy)
    finally:
        # Release the underlying file handle even when parsing fails.
        vcf.close()
    return samples

In [38]:
def find_biallelic_sites(sample_data_1, sample_data_2):
    """
    Walk the variants of two sample-data objects in lockstep and collect
    the site ids of the sites that are biallelic in both.

    A pair of variants is kept when, after discarding `None` from each
    allele list, both hold exactly two alleles and the two allele sets
    are equal. Every pair must share the same site position, otherwise
    an AssertionError is raised.

    Returns a tuple `(sites_1, sites_2)` of site-id lists, one per input.
    """
    sites_1, sites_2 = [], []
    paired_variants = zip(sample_data_1.variants(), sample_data_2.variants())
    for var_1, var_2 in paired_variants:
        assert var_1.site.position == var_2.site.position
        alleles_1 = {allele for allele in var_1.alleles if allele is not None}
        alleles_2 = {allele for allele in var_2.alleles if allele is not None}
        # Keep only biallelic sites
        if len(alleles_1) != 2 or len(alleles_2) != 2:
            continue
        if alleles_1 != alleles_2:
            continue
        sites_1.append(var_1.site.id)
        sites_2.append(var_2.site.id)
    assert len(sites_1) == len(sites_2),\
        "The number of site positions in sites_1 and sites_2 are different."
    return sites_1, sites_2

In [39]:
def get_ts_with_discretized_coordinates(ts):
    """
    Return a copy of `ts` whose site positions are rounded with
    `np.round`.

    Works on the tables returned by `ts.dump_tables()`: after rounding,
    duplicate sites are removed, the tables are sorted and indexed,
    mutation times are computed, and a new tree sequence is built from
    the result.
    """
    tables = ts.dump_tables()
    site_table = tables.sites
    site_table.position = np.round(site_table.position)
    # Rounding can make several sites land on the same position.
    tables.deduplicate_sites()
    tables.sort()
    tables.build_index()
    tables.compute_mutation_times()
    return tables.tree_sequence()

In [40]:
def get_random_site_mask(ts, missing, rng = None):
    """
    Draw a random boolean mask with one entry per site of `ts`.

    Each entry is True independently with probability `missing`.

    Parameters
    ----------
    ts : object exposing `num_sites`.
    missing : proportion of sites to flag, between 0 and 1 inclusive.
    rng : optional `numpy.random.Generator`. When given, the mask is
        drawn from it, which makes the result reproducible for a seeded
        generator. When omitted, NumPy's global random state is used
        (the previous behaviour).

    Returns
    -------
    Boolean NumPy array of length `ts.num_sites`.

    Raises
    ------
    AssertionError : if `missing` lies outside [0, 1].
    """
    assert missing >=0 and missing <= 1,\
        "Proportion of missing sites is not between 0 and 1."
    if rng is None:
        draws = np.random.random(ts.num_sites)
    else:
        draws = rng.random(ts.num_sites)
    return draws < missing

In [41]:
def convert_into_ancestor_tree_sequence(ts, samples):
    """
    Turn a tree sequence into an ancestor tree sequence.

    `ts` and `samples` are forwarded to `make_ancestors_ts` with
    `remove_leaves = True` (intended to drop the tips, i.e. the sample
    nodes at time 0 -- confirm against the tsinfer version in use).

    As an extra step, the metadata schema of the population table is
    reset to `tskit.MetadataSchema(schema = None)` before the result is
    rebuilt from its tables. Only the schema is replaced here; the
    table rows themselves are not edited.
    """
    tipless_tables = make_ancestors_ts(samples = samples,
                                       ts = ts,
                                       remove_leaves = True).dump_tables()
    tipless_tables.populations.metadata_schema = tskit.MetadataSchema(schema = None)
    return tipless_tables.tree_sequence()

In [42]:
def impute_genotypes_using_tsinfer(ref_vcf_file,
                                   miss_vcf_file,
                                   imputed_vcf_file,
                                   contig_id):
    """
    Run the tsinfer pipeline on a reference panel and match query
    samples against the resulting ancestors.

    Steps:
      1. Load both VCF files with `create_sample_data_from_vcf_file`.
      2. Generate ancestors from the reference sample data and match
         them (`tsinfer.generate_ancestors`, `tsinfer.match_ancestors`).
      3. Match the query sample data (from `miss_vcf_file`) against the
         ancestors tree sequence (`tsinfer.match_samples`).
      4. Write the matched tree sequence to `imputed_vcf_file` as an
         uncompressed text VCF using `contig_id`.

    Returns the tree sequence produced in step 3.
    """
    reference_data = create_sample_data_from_vcf_file(ref_vcf_file)
    query_data = create_sample_data_from_vcf_file(miss_vcf_file)

    # Ancestors are inferred from the reference panel only.
    reference_ancestors = tsinfer.generate_ancestors(sample_data = reference_data)
    ancestors_ts = tsinfer.match_ancestors(sample_data = reference_data,
                                           ancestor_data = reference_ancestors)

    matched_ts = tsinfer.match_samples(sample_data = query_data,
                                       ancestors_ts = ancestors_ts)

    with open(imputed_vcf_file, "w") as vcf_out:
        matched_ts.write_vcf(vcf_out, contig_id=contig_id)
    return matched_ts

In [67]:
def impute_genotypes_using_ts_only(ref_vcf_file,
                                   miss_vcf_file,
                                   imputed_vcf_file,
                                   imputed_ts_file,
                                   ts_anc_ref,
                                   contig_id):
    """
    Match query samples against an existing ancestors tree sequence.

    Parameters
    ----------
    ref_vcf_file : unused. Kept so existing keyword calls keep working;
        the reference panel is represented by `ts_anc_ref`, so the
        reference VCF is no longer parsed here.
    miss_vcf_file : query VCF, loaded with
        `create_sample_data_from_vcf_file`.
    imputed_vcf_file : output path; written as gzip-compressed text VCF.
    imputed_ts_file : output path for the matched tree sequence
        (`dump`).
    ts_anc_ref : ancestors tree sequence passed to
        `tsinfer.match_samples` as `ancestors_ts`.
    contig_id : contig identifier passed to `write_vcf`.

    Returns
    -------
    The tree sequence returned by `tsinfer.match_samples`.
    """
    sd_miss = create_sample_data_from_vcf_file(miss_vcf_file)

    # The earlier dump_tables()/tree_sequence() round trip on `ts_anc_ref`
    # changed nothing (its only edit was commented out), so the ancestors
    # tree sequence is used as passed in.
    ts_fixed = tsinfer.match_samples(sample_data=sd_miss,
                                     ancestors_ts=ts_anc_ref)

    with gzip.open(imputed_vcf_file, "wt") as f:
        ts_fixed.write_vcf(f, contig_id=contig_id)
    ts_fixed.dump(imputed_ts_file)

    return ts_fixed

## Create data sets via simulations.

In [46]:
# Root directory of every file this notebook writes (note the trailing slash:
# paths are built by plain string concatenation).
base_dir = "../data/ancient_panmictic_haploid_miss80_time1e2/"

# Sampling design.
size_ref   = 1_000
size_query = 100
sampling_time_query = 100   # `time` of the query sample set
num_replicates = 10
ploidy_level = 1

# Simulation parameters.
eff_pop_size = 10_000
mutation_rate = 1e-8
recombination_rate = 1e-8
sequence_length = 10_000_000 # 10 Mbp

# Masking and VCF output.
proportion_missing_sites = 0.80
contig_id = '1'

print(
    "\n".join(
        [
            f"Size of the reference panel is {size_ref}",
            f"Size of the query is {size_query}",
            f"Ploidy level is {ploidy_level}",
            f"Population size is {eff_pop_size}",
            f"Sampling time query : {sampling_time_query}",
            f"Base directory : {base_dir}",
        ]
    )
)

Size of the reference panel is 1000
Size of the query is 100
Ploidy level is 1
Population size is 10000
Sampling time query : 100
Base directory : ../data/ancient_panmictic_haploid_miss80_time1e2/


In [47]:
# Constant recombination rate along the whole simulated sequence.
rate_map = msprime.RateMap.uniform(sequence_length = sequence_length, rate = recombination_rate)

In [48]:
# Two sample sets, query first and reference panel second: the query is
# sampled at `sampling_time_query`, the reference panel at time 0.
sample_set = [
    msprime.SampleSet(num_samples = group_size,
                      time = group_time,
                      ploidy = ploidy_level)
    for group_size, group_time in (
        (size_query, sampling_time_query),
        (size_ref, 0),
    )
]

In [49]:
# src_ts collects the full simulated tree sequences;
# anc_ts collects the ancestor tree sequences derived from them.
src_ts, anc_ts = [], []

In [50]:
print(f"Simulating {num_replicates} tree sequences.")

# A fixed master seed makes the replicates reproducible across kernel
# restarts. Each replicate gets its own pair of seeds (ancestry, mutations)
# drawn from this seeded generator.
simulation_seed = 20220101
seed_source = random.Random(simulation_seed)

# Empty the list first so that re-running this cell does not append a second
# batch of replicates to the ones already stored.
src_ts.clear()

tic = time.time()

for i in range(num_replicates):
    ancestry_seed = seed_source.randint(1, 2**31 - 1)
    mutation_seed = seed_source.randint(1, 2**31 - 1)
    # NOTE(review): `ploidy` is only set on the sample sets, not on
    # sim_ancestry itself -- confirm that this is the intended model.
    ancestry_ts = msprime.sim_ancestry(
        samples = sample_set,
        population_size = eff_pop_size,
        model = "hudson",
        recombination_rate = rate_map,
        discrete_genome = True,
        random_seed = ancestry_seed
    )
    sim_ts = msprime.sim_mutations(
        ancestry_ts,
        rate = mutation_rate,
        discrete_genome = True,
        random_seed = mutation_seed
    )
    src_ts.append(sim_ts)

toc = time.time()
print(f"Simulation of {num_replicates} ts took {round(toc - tic, 2)} seconds.")

Simulating 10 tree sequences.
Simulation of 10 ts took 5.41 seconds.


In [62]:
# Individual ids: the query occupies the first `size_query` ids and the
# reference panel the `size_ref` ids that follow.
individuals_query = np.arange(size_query, dtype=int)
individuals_ref = np.arange(size_query, size_query + size_ref, dtype=int)

# Human-readable labels built from the individual ids.
individual_names_query = [f"query_{i}" for i in individuals_query]
individual_names_ref = [f"ref_{i}" for i in individuals_ref]

# Sample ids: `ploidy_level` consecutive ids per individual, query first.
samples_query = np.arange(ploidy_level * size_query, dtype=int)
samples_ref = np.arange(
    ploidy_level * size_query,
    ploidy_level * (size_query + size_ref),
    dtype=int,
)

In [64]:
# Make sure every output directory exists; gzip.open() and dump() do not
# create missing parent directories.
for sub_dir in ("ref/", "true/", "miss/", "ts_anc_ref/"):
    os.makedirs(base_dir + sub_dir, exist_ok=True)

# Start from an empty list so that anc_ts[i] always corresponds to src_ts[i],
# even when this cell is run more than once.
anc_ts.clear()

# Only the first simulated tree sequence is processed.
for i, ts in enumerate(src_ts[:1]):
    print(f"Processing ts {i}.")
    tic = time.time()
    
    ref_vcf_file = base_dir + "ref/"  + "ref."  + str(i) + ".vcf.gz"
    true_vcf_file = base_dir + "true/" + "true." + str(i) + ".vcf.gz"
    miss_vcf_file = base_dir + "miss/" + "miss." + str(i) + ".vcf.gz"
    
    ts_full_ref_file = base_dir + "ref/" + "ts_full_ref." + str(i) + ".trees"
    ts_anc_ref_file = base_dir + "ts_anc_ref/" + "ts_anc_ref." + str(i) + ".trees"
    
    print("\tGetting ancestors ts...")
    ts_anc_ref = convert_into_ancestor_tree_sequence(ts, samples=samples_ref)
    anc_ts.append(ts_anc_ref)
    
    # Keep both the full simulated ts and its ancestor ts on disk.
    ts.dump(ts_full_ref_file)
    ts_anc_ref.dump(ts_anc_ref_file)
    
    # Boolean mask over sites, True with probability proportion_missing_sites.
    site_mask = get_random_site_mask(ts, missing=proportion_missing_sites)
    
    print("\tPrinting reference VCF...")
    with gzip.open(ref_vcf_file, "wt") as f:
        ts.write_vcf(f, individuals=individuals_ref)
    print("\tPrinting query VCF with non-missing genotypes...")
    with gzip.open(true_vcf_file, "wt") as f:
        ts.write_vcf(f, individuals=individuals_query)
    print("\tPrinting query VCF with missing genotypes...")
    # The mask is handed to write_vcf as `site_mask`; presumably the flagged
    # sites are left out of this file -- confirm against the tskit docs.
    with gzip.open(miss_vcf_file, "wt") as f:
        ts.write_vcf(f,
                     individuals=individuals_query,
                     site_mask=site_mask)
    
    toc = time.time()
    print(f"Took {toc - tic} seconds to process ts {i}.")

Processing ts 0.
	Getting ancestors ts...
	Printing reference VCF...
	Printing query VCF with non-missing genotypes...
	Printing query VCF with missing genotypes...
Took 15.67206883430481 seconds to process ts 0.


## Perform genotype imputation.

In [68]:
print("Doing imputation using ts only.")

# The imputation function writes into this directory but does not create it.
os.makedirs(base_dir + "imputed_tsonly/", exist_ok=True)

# Only the first replicate is imputed.
for i in range(len(src_ts[:1])):
    print(f"Imputing VCF {i}")
    tic = time.time()
    
    ref_vcf_file     = base_dir + "ref/"  + "ref."  + str(i) + ".vcf.gz"
    miss_vcf_file    = base_dir + "miss/" + "miss." + str(i) + ".vcf.gz"
    imputed_vcf_file = base_dir + "imputed_tsonly/" + "imputed." + str(i) + ".vcf.gz"
    ts_imputed_file  = base_dir + "imputed_tsonly/" + "imputed." + str(i) + ".trees"
    # NOTE(review): the recorded run of this cell stopped with an IndexError
    # raised during this call; the cause is not visible in this notebook.
    ts_imputed       = impute_genotypes_using_ts_only(ref_vcf_file = ref_vcf_file,
                                                      miss_vcf_file = miss_vcf_file,
                                                      imputed_vcf_file = imputed_vcf_file,
                                                      imputed_ts_file = ts_imputed_file,
                                                      ts_anc_ref = anc_ts[i],
                                                      contig_id = contig_id)
    
    toc = time.time()
    print(f"Took {toc - tic} seconds to process ts {i}.")

Doing imputation using ts only.
Imputing VCF 0


IndexError: index 6274 is out of bounds for axis 0 with size 6274

In [10]:
print("Doing imputation using tsinfer.")

# Output directory for the imputed VCF and tree sequence files.
os.makedirs(base_dir + "imputed_tsinfer/", exist_ok=True)

# Only the first replicate is imputed.
for i in range(len(src_ts[:1])):
    print(f"Imputing VCF {i}")
    tic = time.time()
    
    # Inputs carry the ".vcf.gz" suffix used by the cell that writes them.
    ref_vcf_file     = base_dir + "ref/" + "ref."  + str(i) + ".vcf.gz"
    miss_vcf_file    = base_dir + "miss/" + "miss." + str(i) + ".vcf.gz"
    # The output stays ".vcf": impute_genotypes_using_tsinfer writes plain text.
    imputed_vcf_file = base_dir + "imputed_tsinfer/" + "imputed." + str(i) + ".vcf"
    ts_imputed_file  = base_dir + "imputed_tsinfer/" + "imputed." + str(i) + ".trees"
    ts_imputed       = impute_genotypes_using_tsinfer(ref_vcf_file = ref_vcf_file,
                                                      miss_vcf_file = miss_vcf_file,
                                                      imputed_vcf_file = imputed_vcf_file,
                                                      contig_id = contig_id)
    ts_imputed.dump(ts_imputed_file)
    
    toc = time.time()
    print(f"Took {toc - tic} seconds to process ts {i}.")

Doing imputation using tsinfer.
Imputing VCF 0


In [12]:
print("Doing imputation using BEAGLE.")

beagle_exe = "../analysis/beagle/beagle.28Jun21.220.jar"

# This cell only prints the BEAGLE command line; it does not run it.
# Only replicate 0 is handled (use `range(len(src_ts))` for all replicates).
for i in [0]:
    # Inputs carry the ".vcf.gz" suffix used by the cell that writes them.
    ref_vcf_file     = base_dir + "ref/"  + "ref."  + str(i) + ".vcf.gz"
    miss_vcf_file    = base_dir + "miss/" + "miss." + str(i) + ".vcf.gz"
    # Passed as BEAGLE's `out=` argument; no file suffix is given here.
    imputed_vcf_file = base_dir + "imputed_beagle/" + "imputed." + str(i)
    beagle_cmd = [
        "java", "-jar", beagle_exe,
        "ref=" + ref_vcf_file,
        "gt="  + miss_vcf_file,
        "out=" + imputed_vcf_file
    ]
    beagle_cmd = " ".join(beagle_cmd)
    print(beagle_cmd + "\n")

Doing imputation using BEAGLE.
java -jar ../analysis/beagle/beagle.28Jun21.220.jar ref=../data/ancient_panmictic_haploid_miss80/ref/ref.0.vcf gt=../data/ancient_panmictic_haploid_miss80/miss/miss.0.vcf out=../data/ancient_panmictic_haploid_miss80/imputed_beagle/imputed.0

