In [35]:
import math
import os
import random
import sys
import time

import gzip

import numpy as np

import msprime
import tskit
import tsinfer
from tsinfer import make_ancestors_ts

import cyvcf2
import demes
import demesdraw

print(f"tskit {tskit.__version__}")
print(f"tsinfer {tsinfer.__version__}")
print(f"msprime {msprime.__version__}")
print(f"demes {demes.__version__}")
print(f"cyvcf2 {cyvcf2.__version__}")

tskit 0.5.0
tsinfer 0.2.3
msprime 1.1.1
demes 0.2.1
cyvcf2 0.30.14


### Simulate genealogy and genetic variation

In [36]:
sampling_time_query = 100

size_query = 100
size_ref   = 1_000

eff_pop_size = 10_000
mutation_rate = 1e-8
recombination_rate = 1e-8

proportion_missing_sites = 0.20

contig_id = '1'
ploidy_level = 1
sequence_length = 1_000_000

In [37]:
print(f"Size of the reference panel is {size_ref}")
print(f"Size of the query is {size_query}")
print(f"Ploidy level is {ploidy_level}")
print(f"Population size is {eff_pop_size}")
print(f"Sampling time query : {sampling_time_query}")

Size of the reference panel is 1000
Size of the query is 100
Ploidy level is 1
Population size is 10000
Sampling time query : 100


In [38]:
rate_map = msprime.RateMap.uniform(
    sequence_length=sequence_length,
    rate=recombination_rate,
)

In [39]:
sample_set = [
    msprime.SampleSet(num_samples=size_query,
                      time=sampling_time_query,
                      ploidy=ploidy_level),
    msprime.SampleSet(num_samples=size_ref,
                      time=0,
                      ploidy=ploidy_level),
]

In [40]:
ts_full = msprime.sim_mutations(
    msprime.sim_ancestry(
        samples=sample_set,
        population_size=eff_pop_size,
        model="hudson",
        recombination_rate=rate_map,
        discrete_genome=True,
    ),
    rate=mutation_rate,
    discrete_genome=True,
)

In [41]:
individuals_query = np.arange(size_query, dtype=int)
individual_names_query = ["query_" + str(i) for i in individuals_query]
samples_query = np.arange(ploidy_level * size_query, dtype=int)

In [42]:
individuals_ref = np.arange(size_query, size_query + size_ref, dtype=int)
individual_names_ref = ["ref_" + str(i) for i in individuals_ref]
samples_ref = np.arange(ploidy_level * size_query, ploidy_level * (size_query + size_ref), dtype=int)

In [43]:
print("Samples - query")
print(samples_query[:10])
print(samples_query[-10:])
print("\n")
print("Sample - reference panel")
print(samples_ref[:10])
print(samples_ref[-10:])

Samples - query
[0 1 2 3 4 5 6 7 8 9]
[90 91 92 93 94 95 96 97 98 99]


Sample - reference panel
[100 101 102 103 104 105 106 107 108 109]
[1090 1091 1092 1093 1094 1095 1096 1097 1098 1099]


### Match query genomes to an ancestor tree sequence

In [44]:
def create_sample_data_from_tree_sequence(ts, skip_multiallelic_sites=False):
    """
    Match samples in a SampleData object to an ancestors ts,
    without the need to read from a VCF file.
    
    Optionally, skip multi-allelic sites.
    """
    with tsinfer.SampleData(ts.sequence_length) as sample_data:
        for variant in ts.variants():
            # Skip multi-allelic sites
            if skip_multiallelic_sites and len(variant.alleles) > 2:
                continue
            sample_data.add_site(
                position=variant.site.position,
                genotypes=variant.genotypes,
                alleles=variant.alleles,
            )
    return(sample_data)

In [45]:
sd_full = create_sample_data_from_tree_sequence(ts_full, skip_multiallelic_sites=True)

sd_query = sd_full.subset(individuals=individuals_query) # Used for matching

In [46]:
ts_ref = ts_full.simplify(samples_ref, filter_sites=False) # Remove private branches

ts_anc = make_ancestors_ts(ts=ts_ref, remove_leaves=True, samples=None)

tmp_tables = ts_anc.dump_tables()
tmp_tables.individuals.clear()
tmp_tables.populations.metadata_schema = tskit.MetadataSchema(schema=None)

ts_anc = tmp_tables.tree_sequence() # Used for matching

In [47]:
ts_matched = tsinfer.match_samples(sample_data=sd_query, ancestors_ts=ts_anc)
ts_matched

Tree Sequence,Unnamed: 1
Trees,1252
Sequence Length,1000000.0
Time Units,generations
Sample Nodes,100
Total Size,553.6 KiB
Metadata,dict

Table,Rows,Size,Has Metadata
Edges,7021,219.4 KiB,
Individuals,100,3.0 KiB,✅
Migrations,0,8 Bytes,
Mutations,1975,71.4 KiB,
Nodes,1826,49.9 KiB,
Populations,0,8 Bytes,
Provenances,8,5.3 KiB,
Sites,2989,149.7 KiB,✅


### Impute into query genomes from the ancestor tree sequence

In [48]:
def mask_sites_in_sample_data(sd, sequence_length, sites_to_mask=[]):
    """
    Create a SampleData object from an existing SampleData object,
    while masking out specified sites listed in sites_to_mask.
    """
    with tsinfer.SampleData(sequence_length) as sample_data:
        for i, variant in enumerate(sd.variants()):
            if i in sites_to_mask:
                sample_data.add_site(
                    position=variant.site.position,
                    genotypes=np.repeat(tskit.MISSING_DATA, len(variant.genotypes)),
                    alleles=variant.alleles, # Keep the same
                )
            else:
                sample_data.add_site(
                    position=variant.site.position,
                    genotypes=variant.genotypes,
                    alleles=variant.alleles,
                )
    return(sample_data)

In [49]:
masked_sites = random.sample(
    [i for i in range(sd_query.num_sites)],
    int(np.floor(sd_query.num_sites * proportion_missing_sites))
)

print(f"Total sites. : {sd_query.num_sites}")
print(f"Masked sites : {len(masked_sites)}")
print("\n")
print("Masked site positions")
print(masked_sites)

Total sites. : 2989
Masked sites : 597


Masked site positions
[2306, 540, 2727, 1736, 2493, 2425, 46, 320, 784, 1753, 1228, 2184, 2419, 2554, 208, 2689, 1288, 986, 2667, 2662, 1209, 376, 1345, 2628, 1512, 2161, 270, 456, 2068, 837, 1268, 2871, 30, 1824, 6, 987, 1022, 2028, 1008, 1100, 998, 1503, 1365, 794, 1469, 715, 2466, 1883, 1057, 2404, 1023, 720, 2661, 813, 2222, 2892, 407, 765, 2730, 1474, 84, 996, 455, 2562, 2017, 2597, 2531, 2265, 93, 685, 701, 2563, 1482, 1591, 595, 2234, 2054, 1852, 1740, 2783, 650, 1836, 1357, 25, 2913, 2802, 2434, 1047, 2513, 2716, 2270, 888, 2449, 2492, 104, 2206, 2176, 143, 29, 307, 1928, 2283, 1440, 2067, 1617, 2584, 1419, 553, 1494, 2719, 683, 1150, 285, 1321, 103, 2432, 157, 564, 1152, 2353, 2527, 1245, 371, 2617, 1040, 703, 487, 468, 415, 163, 2146, 2001, 2415, 2747, 1354, 1128, 1460, 1620, 593, 749, 102, 1889, 449, 485, 512, 2064, 846, 746, 887, 1234, 2820, 1682, 1250, 2246, 1858, 890, 583, 1898, 950, 2383, 2973, 1256, 2915, 1275, 2130, 1594, 2437, 

In [50]:
sd_query_masked = mask_sites_in_sample_data(
    sd_query,
    sequence_length=sequence_length,
    sites_to_mask=masked_sites
)

In [57]:
ts_imputed = tsinfer.match_samples(
    sample_data=sd_query_masked,
    ancestors_ts=ts_anc
)
ts_imputed

Tree Sequence,Unnamed: 1
Trees,1248
Sequence Length,1000000.0
Time Units,generations
Sample Nodes,100
Total Size,553.9 KiB
Metadata,dict

Table,Rows,Size,Has Metadata
Edges,7049,220.3 KiB,
Individuals,100,3.0 KiB,✅
Migrations,0,8 Bytes,
Mutations,1942,70.2 KiB,
Nodes,1844,50.4 KiB,
Populations,0,8 Bytes,
Provenances,8,5.3 KiB,
Sites,2989,149.6 KiB,✅


### Compare the true and imputed tree sequences

In [52]:
ts_matched

Tree Sequence,Unnamed: 1
Trees,1252
Sequence Length,1000000.0
Time Units,generations
Sample Nodes,100
Total Size,553.6 KiB
Metadata,dict

Table,Rows,Size,Has Metadata
Edges,7021,219.4 KiB,
Individuals,100,3.0 KiB,✅
Migrations,0,8 Bytes,
Mutations,1975,71.4 KiB,
Nodes,1826,49.9 KiB,
Populations,0,8 Bytes,
Provenances,8,5.3 KiB,
Sites,2989,149.7 KiB,✅


In [58]:
ts_imputed

Tree Sequence,Unnamed: 1
Trees,1248
Sequence Length,1000000.0
Time Units,generations
Sample Nodes,100
Total Size,553.9 KiB
Metadata,dict

Table,Rows,Size,Has Metadata
Edges,7049,220.3 KiB,
Individuals,100,3.0 KiB,✅
Migrations,0,8 Bytes,
Mutations,1942,70.2 KiB,
Nodes,1844,50.4 KiB,
Populations,0,8 Bytes,
Provenances,8,5.3 KiB,
Sites,2989,149.6 KiB,✅


### Write results to VCF

In [59]:
true_vcf_file = "test.true.vcf.gz"
imputed_vcf_file = "test.imputed.vcf.gz"

In [60]:
with gzip.open(true_vcf_file, "wt") as f:
    ts_matched.write_vcf(f)

In [61]:
with gzip.open(imputed_vcf_file, "wt") as f:
    ts_imputed.write_vcf(f)