In [1]:
from copy import deepcopy
import math
import numpy as np
import os
import random
import sys
import time

from IPython.display import SVG

sys.path.append("../modules/")
import mask_genotype
import parse_vcf

import demes
import tsinfer
from tsinfer import make_ancestors_ts
import tskit
import msprime
import cyvcf2

print(f"tskit {tskit.__version__}")
print(f"tsinfer {tsinfer.__version__}")
print(f"msprime {msprime.__version__}")
print(f"stdpopsim {stdpopsim.__version__}")
print(f"cyvcf2 {cyvcf2.__version__}")

tskit 0.4.1
tsinfer 0.2.3.dev9+gc8568d5
msprime 1.1.1
stdpopsim 0.1.2
cyvcf2 0.30.14


In [2]:
def print_sample_data_to_vcf(sample_data,
                             individuals,
                             samples,
                             ploidy_level,
                             mask,
                             out_vcf_file,
                             contig_id,
                             sequence_length_max = 1e12):
    """
    Fields:
    CHROM contig_id
    POS row index in genotype_matrix
    ID .
    REF ancestral allele
    ALT derived allele(s)
    QUAL .
    FILTER PASS
    INFO
    FORMAT GT
    individual 0
    individual 1
    ...
    individual n - 1; n = number of individuals
    """
    CHROM = contig_id
    ID = '.'
    QUAL = '.'
    FILTER = 'PASS'
    FORMAT = 'GT'
    
    assert ploidy_level == 1 or ploidy_level == 2,\
        f"Specified ploidy_level {ploidy_level} is not recognized."
    
    assert ploidy_level * len(individuals) == len(samples),\
        f"Some individuals may not have the same ploidy level of {ploidy_level}."
    
    # Assume that both sample and individual ids are ordered the same way.
    #individual_id_map = np.repeat(individuals, 2)
    
    header  = "##fileformat=VCFv4.2\n"\
            + "##source=tskit " + tskit.__version__ + "\n"\
            + "##INFO=<ID=AA,Number=1,Type=String,Description=\"Ancestral Allele\">\n"\
            + "##FORMAT=<ID=GT,Number=1,Type=String,Description=\"Genotype\">\n"
    header += "##contig=<ID=" + contig_id + "," + "length=" + str(int(ts.sequence_length)) + ">\n"
    header += "\t".join(['#CHROM', 'POS', 'ID', 'REF', 'ALT', 'QUAL', 'FILTER', 'INFO', 'FORMAT']\
                        + ["s" + str(x) for x in individuals])
    
    with open(out_vcf_file, "w") as vcf:
        vcf.write(header + "\n")
        for i, variant in enumerate(ts.variants()):
            site_id = variant.site.id
            POS = int(np.round(variant.site.position))
            if POS > sequence_length_max:
                break
            # Since the tree sequence was produced using simulation,
            #    there's no reference sequence other than the ancestral sequence.
            REF = variant.site.ancestral_state
            alt_alleles = list(set(variant.alleles) - {REF})
            AA = variant.site.ancestral_state
            ALT = ",".join(alt_alleles) if len(alt_alleles) > 0 else "."
            INFO = "AA" + "=" + AA
            record = [str(x)
                      for x
                      in [CHROM, POS, ID, REF, ALT, QUAL, FILTER, INFO, FORMAT]]
            
            for j in individuals:
                #sample_ids = [samples[x]
                #              for x
                #              in np.where(individual_id_map == j)[0].tolist()]
                #genotype = "|".join([str(variant.genotypes[k])
                #                     for k
                #                     in sample_ids])
                if ploidy_level == 1:
                    genotype = str(variant.genotypes[j])
                else:
                    genotype = str(variant.genotypes[2 * j]) + "|" + str(variant.genotypes[2 * j + 1])
                    
                if mask is not None and mask.query_position(individual = j, position = POS) == True:
                    if ploidy_level == 1:
                        genotype = '.'
                    else:
                        genotype = '.|.' # Or "./."
                record += [genotype]
                
            vcf.write("\t".join(record) + "\n")

In [3]:
# Sourced and modified from:
# https://tsinfer.readthedocs.io/en/latest/tutorial.html#data-example
def get_chromosome_length(vcf):
    assert len(vcf.seqlens) == 1
    return vcf.seqlens[0]


def add_populations(vcf,
                    samples):
    """
    TODO
    """
    pop_ids = [sample_name[0] for sample_name in vcf.samples]
    pop_codes = np.unique(pop_ids)
    pop_lookup = {}
    for p in pop_codes:
        pop_lookup[p] = samples.add_population(metadata = {"name" : p})
    return [pop_lookup[pop_id] for pop_id in pop_ids]


def add_individuals(vcf,
                    samples,
                    ploidy_level,
                    populations):
    for name, population in zip(vcf.samples, populations):
        samples.add_individual(ploidy = ploidy_level,
                               metadata = {"name": name},
                               population = population)


def add_sites(vcf,
              samples,
              ploidy_level,
              warn_monomorphic_sites = False):
    """
    Read the sites in the VCF and add them to the samples object,
    reordering the alleles to put the ancestral allele first,
    if it is available.
    """
    assert ploidy_level == 1 or ploidy_level == 2,\
        f"ploidy_level {ploidy_level} is not recognized."
    
    pos = 0
    for variant in vcf:
        # Check for duplicate site positions.
        if pos == variant.POS:
            raise ValueError("Duplicate positions for variant at position", pos)
        else:
            pos = variant.POS
        # Check that the genotypes are phased.
        #if any([not phased for _, _, phased in variant.genotypes]):
        #    raise ValueError("Unphased genotypes for variant at position", pos)
        alleles = [variant.REF] + variant.ALT # Exactly as in the input VCF file.
        if warn_monomorphic_sites:
            if len(set(alleles) - {'.'}) == 1:
                print(f"Monomorphic site at {pos}")
        ancestral = variant.INFO.get("AA", variant.REF) # Dangerous action!!!
        # Ancestral state must be first in the allele list.
        ordered_alleles = [ancestral] + list(set(alleles) - {ancestral})
        # Create an index mapping from the input VCF to tsinfer input.
        allele_index = {
            old_index: ordered_alleles.index(allele)
            for old_index, allele in enumerate(alleles)
        }
        # When genotype is missing...
        if variant.num_unknown > 0:
            allele_index[-1] = tskit.MISSING_DATA
            ordered_alleles += [None]
        # Map original allele indexes to their indexes in the new alleles list.
        genotypes = [
            allele_index[old_index]
            for row in variant.genotypes # cyvcf2 uses -1 to indicate missing data.
            for old_index in row[0:ploidy_level] # Each is a 3-tuple (allele 1, allele 2, is phased?).
        ]
        samples.add_site(pos,
                         genotypes = genotypes,
                         alleles = ordered_alleles)


def create_sample_data_from_vcf_file(vcf_file):
    vcf = cyvcf2.VCF(vcf_file,
                     gts012 = False, # 0=HOM_REF, 1=HET, 2=UNKNOWN, 3=HOM_ALT
                     strict_gt = True)
    with tsinfer.SampleData(
        sequence_length = get_chromosome_length(vcf)
    ) as samples:
        populations = add_populations(vcf, samples)
        add_individuals(vcf, samples, ploidy_level, populations)
        add_sites(vcf, samples, ploidy_level)
    return(samples)

In [6]:
def get_ts_with_discretized_coordinates(ts):
    ts_tables = ts.dump_tables()
    ts_tables.sites.position = np.round(ts_tables.sites.position)
    ts_tables.deduplicate_sites()
    ts_tables.sort()
    ts_tables.build_index()
    ts_tables.compute_mutation_times()
    ts_discretized = ts_tables.tree_sequence()
    return(ts_discretized)

In [7]:
def impute_genotypes_using_tsinfer(ref_vcf_file,
                                   miss_vcf_file,
                                   imputed_vcf_file,
                                   contig_id):
    sd_ref  = create_sample_data_from_vcf_file(ref_vcf_file)
    sd_miss = create_sample_data_from_vcf_file(miss_vcf_file)
    ad_ref     = tsinfer.generate_ancestors(sample_data = sd_ref)
    # This step is to infer a tree sequence from the sample data.
    ts_anc_ref = tsinfer.match_ancestors(sample_data   = sd_ref,
                                         ancestor_data = ad_ref)
    ts_matched = tsinfer.match_samples(sample_data  = sd_miss,
                                       ancestors_ts = ts_anc_ref)
    with open(imputed_vcf_file, "w") as vcf:
        ts_matched.write_vcf(vcf, contig_id = contig_id)

In [8]:
def impute_genotypes_using_ts_only(ref_vcf_file,
                                   miss_vcf_file,
                                   imputed_vcf_file,
                                   ts_anc_ref,
                                   contig_id):
    sd_ref  = create_sample_data_from_vcf_file(ref_vcf_file)
    sd_miss = create_sample_data_from_vcf_file(miss_vcf_file)
    ts_matched = tsinfer.match_samples(sample_data  = sd_miss,
                                       ancestors_ts = ts_anc_ref)
    with open(imputed_vcf_file, "w") as vcf:
        ts_matched.write_vcf(vcf, contig_id = contig_id)

## Create data sets via simulations.

In [9]:
size_query =  1_000

size_amh   =    100
size_yri   =    250
size_ceu   =  9_500
size_chb   =    250
size_ref   = size_yri + size_ceu + size_chb

print(f"Number of samples from AMH : {size_amh}")
print(f"Number of samples from YRI : {size_yri}")
print(f"Number of samples from CEU : {size_ceu}")
print(f"Number of samples from CHB : {size_chb}")
print(f"Size of reference panel    : {size_ref}")

mutation_rate = 1e-8
recombination_rate = 1e-8

num_replicates = 1

proportion_missing_sites = 0.10
num_missing_sites = 1_000

contig_id = '1'
ploidy_level = 1
sequence_length = 1e7

base_dir = "../data/modern_ooa_unequal_900505_haploid_miss10/"

Number of samples from AMH : 100
Number of samples from YRI : 250
Number of samples from CEU : 9500
Number of samples from CHB : 250
Size of reference panel    : 10000


In [13]:
rate_map = msprime.RateMap.uniform(
    sequence_length = sequence_length,
    rate = mutation_rate
)

In [14]:
yaml_file = "../demes/gutenkunst_ooa_2009.yaml"
ooa_graph = demes.load(yaml_file)
demography_model = msprime.Demography.from_demes(ooa_graph)
demography_model

id,name,description,initial_size,growth_rate,default_sampling_time,extra_metadata
0,ancestral,Equilibrium/root population,0.0,0.0,8800.0,{}
1,AMH,Anatomically modern humans,0.0,0.0,5600.0,{}
2,OOA,Bottleneck out-of-Africa population,0.0,0.0,850.0,{}
3,YRI,"Yoruba in Ibadan, Nigeria",12300.0,0.0,0.0,{}
4,CEU,Utah Residents (CEPH) with Northern and Western European Ancestry,29725.0,0.004,0.0,{}
5,CHB,"Han Chinese in Beijing, China",54090.0,0.0055,0.0,{}

Unnamed: 0,ancestral,AMH,OOA,YRI,CEU,CHB
ancestral,0,0,0,0.0,0.0,0.0
AMH,0,0,0,0.0,0.0,0.0
OOA,0,0,0,0.0,0.0,0.0
YRI,0,0,0,0.0,3e-05,1.9e-05
CEU,0,0,0,3e-05,0.0,9.6e-05
CHB,0,0,0,1.9e-05,9.6e-05,0.0

time,type,parameters,effect
848,Population parameter change,"population=OOA, initial_size=2100",initial_size → 2.1e+03 for population OOA
848,Population Split,"derived=[CEU, CHB], ancestral=OOA","Moves all lineages from derived populations 'CEU' and 'CHB' to the ancestral 'OOA' population. Also set the derived populations to inactive, and all migration rates to and from the derived populations to zero."
848,Migration rate change,"source=CEU, dest=YRI, rate=0",Backwards-time migration rate from CEU to YRI → 0
848,Migration rate change,"source=YRI, dest=CEU, rate=0",Backwards-time migration rate from YRI to CEU → 0
848,Migration rate change,"source=CHB, dest=YRI, rate=0",Backwards-time migration rate from CHB to YRI → 0
848,Migration rate change,"source=YRI, dest=CHB, rate=0",Backwards-time migration rate from YRI to CHB → 0
848,Migration rate change,"source=CHB, dest=CEU, rate=0",Backwards-time migration rate from CHB to CEU → 0
848,Migration rate change,"source=CEU, dest=CHB, rate=0",Backwards-time migration rate from CEU to CHB → 0
848,Migration rate change,"source=OOA, dest=YRI, rate=0.00025",Backwards-time migration rate from OOA to YRI → 0.00025
848,Migration rate change,"source=YRI, dest=OOA, rate=0.00025",Backwards-time migration rate from YRI to OOA → 0.00025


In [17]:
sample_set = [
    msprime.SampleSet(num_samples = size_amh,
                      population = "AMH", # id = 1
                      ploidy = ploidy_level),
    msprime.SampleSet(num_samples = size_query + size_yri,
                      population = "YRI", # id = 3
                      ploidy = ploidy_level),
    msprime.SampleSet(num_samples = size_ceu,
                      population = "CEU", # id = 4
                      ploidy = ploidy_level),
    msprime.SampleSet(num_samples = size_chb,
                      population = "CHB", # id = 5
                      ploidy = ploidy_level)
]

src_ts = [] # List of full simulated ts.

tic = time.time()

print(f"Simulating {num_replicates} ts without duplicate site positions.")
success = 0

while success < num_replicates:
    sim_ts = msprime.sim_mutations(
        msprime.sim_ancestry(
            samples = sample_set,
            demography = demography_model,
            ploidy = ploidy_level,
            model = "hudson",
            recombination_rate = rate_map,
            discrete_genome = True
        ),
        rate = mutation_rate,
        discrete_genome = True
    )
    
    src_ts.append(sim_ts)
    success += 1
    
toc = time.time()
print(f"Simulation of {num_replicates} ts took {round(toc - tic, 2)} seconds.")

Simulating 1 ts without duplicate site positions.
Simulation of 1 ts took 0.99 seconds.


In [18]:
# Impute into YRI samples.
ts = src_ts[0]

individuals_ancient = ts.samples(population = 1)
samples_ancient     = individuals_ancient # When haploid

individuals_query = ts.samples(population = 3)[:size_query]
samples_query     = individuals_query # When haploid

individuals_ref   = np.concatenate([ts.samples(population = 3)[size_query:],
                                    ts.samples(population = 4),
                                    ts.samples(population = 5)])
samples_ref       = individuals_ref # When haploid

gt_mask = mask_genotype.MissingGenotypeMask(individuals         = individuals_query,
                                            sequence_length     = sequence_length,
                                            proportion_missing  = proportion_missing_sites,
                                            num_regions_missing = num_missing_sites,
                                            contig_id           = contig_id)

print(f"Number of ancient samples   : {len(individuals_ancient)}")
print(f"Number of query     samples : {len(individuals_query)}")
print(f"Number of reference samples : {len(individuals_ref)}")

Number of ancient samples   : 100
Number of query     samples : 1000
Number of reference samples : 10000


In [21]:
anc_ts = [] # List of simulated ancestor ts.

for i, ts in enumerate(src_ts):
    print(f"Processing ts {i}.")
    ancient_vcf_file  = base_dir + "ancient/"  + "ancient."  + str(i) + ".vcf"
    ref_vcf_file  = base_dir + "ref/"  + "ref."  + str(i) + ".vcf"
    true_vcf_file = base_dir + "true/" + "true." + str(i) + ".vcf"
    miss_vcf_file = base_dir + "miss/" + "miss." + str(i) + ".vcf"
    ts_anc_ref_file = base_dir + "ts_anc_ref/" + "ts_anc_ref." + str(i) + ".trees"
    
    sd_all = tsinfer.SampleData.from_tree_sequence(ts, use_sites_time = False)
    sd_ancient = sd_all.subset(individuals = individuals_ancient)
    sd_ref     = sd_all.subset(individuals = individuals_ref)
    sd_query   = sd_all.subset(individuals = individuals_query)
    
    sites_to_keep     = find_biallelic_sites(sd_ref, sd_query)
    sd_ref_filtered   =   sd_ref.subset(sites = sites_to_keep[0])
    sd_query_filtered = sd_query.subset(sites = sites_to_keep[1])
    
    # TODO: Refactor.
    #print("Printing ancestors ts.")
    #sim_ts_anc_ref = make_ancestors_ts(samples = samples_ref,
    #                                   ts = ts,
    #                                   remove_leaves = True)
    #tmp_tables = sim_ts_anc_ref.dump_tables()
    #tmp_tables.populations.metadata_schema = tskit.MetadataSchema(schema = None)
    #sim_ts_anc_ref = tmp_tables.tree_sequence()
    #anc_ts.append(sim_ts_anc_ref)
    #sim_ts_anc_ref.dump(ts_anc_ref_file)
    
    print("Printing ancient VCF.")
    print_sample_data_to_vcf(sample_data = sd_ancient,
                             individuals = individuals_ancient,
                             samples = samples_ancient,
                             ploidy_level = ploidy_level,
                             mask = None,
                             out_vcf_file = ancient_vcf_file,
                             contig_id = contig_id,
                             sequence_length_max = 1e24)
    
    #print("Printing reference panel VCF.")
    #print_sample_data_to_vcf(sample_data = sd_ref_filtered,
    #                         individuals = individuals_ref,
    #                         samples = samples_ref,
    #                         ploidy_level = ploidy_level,
    #                         mask = None,
    #                         out_vcf_file = ref_vcf_file,
    #                         contig_id = contig_id,
    #                         sequence_length_max = 1e24)
    
    print("Printing query VCF with non-missing genotypes.")
    print_sample_data_to_vcf(sample_data = sd_query_filtered,
                             individuals = individuals_query,
                             samples = samples_query,
                             ploidy_level = ploidy_level,
                             mask = None,
                             out_vcf_file = true_vcf_file,
                             contig_id = contig_id,
                             sequence_length_max = 1e24)
    
    #print("Printing query VCF with missing genotypes.")
    #print_sample_data_to_vcf(sample_data = sd_query_filtered,
    #                         individuals = individuals_query,
    #                         samples = samples_query,
    #                         ploidy_level = ploidy_level,
    #                         mask = gt_mask,
    #                         out_vcf_file = miss_vcf_file,
    #                         contig_id = contig_id,
    #                         sequence_length_max = 1e24)

Processing ts 0.
Printing ancient VCF.
Printing query VCF with non-missing genotypes.


## Perform genotype imputation.

In [29]:
print("Doing imputation using ts only.")

for i in np.arange(len(src_ts)):
    print(f"Imputing VCF {i}")
    ref_vcf_file     = base_dir + "ref/"  + "ref."  + str(i) + ".vcf"
    miss_vcf_file    = base_dir + "miss/" + "miss." + str(i) + ".vcf"
    imputed_vcf_file = base_dir + "imputed_tsonly/" + "imputed." + str(i) + ".vcf"
    impute_genotypes_using_ts_only(ref_vcf_file = ref_vcf_file,
                                   miss_vcf_file = miss_vcf_file,
                                   imputed_vcf_file = imputed_vcf_file,
                                   ts_anc_ref = anc_ts[i],
                                   contig_id = contig_id)

Doing imputation using ts only.
Imputing VCF 0


In [10]:
print("Doing imputation using tsinfer.")

#for i in np.arange(len(src_ts)):
for i in np.arange(1):
    print(f"Imputing VCF {i}")
    ref_vcf_file     = base_dir + "ref/" + "ref."  + str(i) + ".vcf"
    miss_vcf_file    = base_dir + "miss/" + "miss." + str(i) + ".vcf"
    imputed_vcf_file = base_dir + "imputed_tsinfer/" + "imputed." + str(i) + ".vcf"
    impute_genotypes_using_tsinfer(ref_vcf_file = ref_vcf_file,
                                   miss_vcf_file = miss_vcf_file,
                                   imputed_vcf_file = imputed_vcf_file,
                                   contig_id = contig_id)

Doing imputation using tsinfer.
Imputing VCF 0


In [31]:
print("Doing imputation using BEAGLE.")

beagle_exe = "../analysis/beagle/beagle.28Jun21.220.jar"

for i in np.arange(len(src_ts)):
    print(f"Imputing VCF {i}")
    ref_vcf_file     = base_dir + "ref/"  + "ref."  + str(i) + ".vcf"
    miss_vcf_file    = base_dir + "miss/" + "miss." + str(i) + ".vcf"
    imputed_vcf_file = base_dir + "imputed_beagle/" + "imputed." + str(i)
    beagle_cmd = [
        "java", "-jar", beagle_exe,
        "ref=" + ref_vcf_file,
        "gt="  + miss_vcf_file,
        "out=" + imputed_vcf_file
    ]
    beagle_cmd = " ".join(beagle_cmd)
    print(beagle_cmd + "\n")

Doing imputation using BEAGLE.
Imputing VCF 0
java -jar ../analysis/beagle/beagle.28Jun21.220.jar ref=../data/modern_ooa_unequal_900505_haploid_miss10/ref/ref.0.vcf gt=../data/modern_ooa_unequal_900505_haploid_miss10/miss/miss.0.vcf out=../data/modern_ooa_unequal_900505_haploid_miss10/imputed_beagle/imputed.0

