In [1]:
import math
import os
import random
import sys
import time

import gzip

import numpy as np

import msprime
import tskit
import tsinfer
from tsinfer import make_ancestors_ts

import cyvcf2
import demes
import demesdraw

# Report the versions of the main libraries, one "<name> <version>" per line.
print(
    "\n".join(
        f"{module.__name__} {module.__version__}"
        for module in (tskit, tsinfer, msprime, demes, cyvcf2)
    )
)

tskit 0.5.0
tsinfer 0.2.3
msprime 1.1.1
demes 0.2.1
cyvcf2 0.30.14


### Simulate genealogy and genetic variation

In [9]:
# Sample sizes: the query set is sampled at `sampling_time_query`,
# the reference panel at time 0 (see the sample set definition).
size_ref = 100
size_query = 10
sampling_time_query = 100

# Parameters handed to the msprime simulation.
eff_pop_size = 10_000
recombination_rate = 1e-6
mutation_rate = 1e-6

# Genome layout.
sequence_length = 1_000
ploidy_level = 1
contig_id = '1'

# Presumably the fraction of query sites to mask before imputation -- confirm.
proportion_missing_sites = 0.80

# Echo the key settings, one per line.
print(
    f"Size of the reference panel is {size_ref}",
    f"Size of the query is {size_query}",
    f"Ploidy level is {ploidy_level}",
    f"Population size is {eff_pop_size}",
    f"Sampling time query : {sampling_time_query}",
    sep="\n",
)

Size of the reference panel is 100
Size of the query is 10
Ploidy level is 1
Population size is 10000
Sampling time query : 100


In [10]:
# Constant recombination rate along the whole simulated sequence.
rate_map = msprime.RateMap.uniform(rate=recombination_rate, sequence_length=sequence_length)

In [11]:
# Two sample sets, in this order: the query genomes (sampled at
# `sampling_time_query`) followed by the reference panel (sampled at time 0).
sample_set = [
    msprime.SampleSet(num_samples=num_samples, time=time, ploidy=ploidy_level)
    for num_samples, time in (
        (size_query, sampling_time_query),
        (size_ref, 0),
    )
]

In [15]:
# Step 1: simulate the genealogy of all query and reference samples.
# NOTE: no random_seed is passed, so every run draws a different genealogy.
ts_ancestry = msprime.sim_ancestry(
    samples=sample_set,
    population_size=eff_pop_size,
    model="hudson",
    recombination_rate=rate_map,
    discrete_genome=True,
)

# Step 2: overlay mutations on that genealogy (also unseeded).
ts_full = msprime.sim_mutations(
    ts_ancestry,
    rate=mutation_rate,
    discrete_genome=True,
)

In [35]:
# IDs of the query individuals and of their sample nodes.
# Assumes the simulation numbers individuals/nodes in sample-set order,
# with the query set first -- TODO confirm.
individuals_query = np.arange(0, size_query, dtype=int)
individual_names_query = [f"query_{i}" for i in individuals_query]
samples_query = np.arange(0, ploidy_level * size_query, dtype=int)

In [36]:
# IDs of the reference panel individuals and of their sample nodes; both
# ranges start right after the query IDs.
individuals_ref = np.arange(size_query, size_query + size_ref, dtype=int)
individual_names_ref = [f"ref_{i}" for i in individuals_ref]
samples_ref = np.arange(
    ploidy_level * size_query,
    ploidy_level * (size_query + size_ref),
    dtype=int,
)

In [135]:
# Show which sample node IDs belong to the query and to the reference panel.
print(
    "Samples - query",
    samples_query,
    "\n",
    "Sample - reference panel",
    samples_ref,
    sep="\n",
)

Samples - query
[0 1 2 3 4 5 6 7 8 9]


Sample - reference panel
[ 10  11  12  13  14  15  16  17  18  19  20  21  22  23  24  25  26  27
  28  29  30  31  32  33  34  35  36  37  38  39  40  41  42  43  44  45
  46  47  48  49  50  51  52  53  54  55  56  57  58  59  60  61  62  63
  64  65  66  67  68  69  70  71  72  73  74  75  76  77  78  79  80  81
  82  83  84  85  86  87  88  89  90  91  92  93  94  95  96  97  98  99
 100 101 102 103 104 105 106 107 108 109]


### Match query genomes to an ancestor tree sequence

In [137]:
def create_sample_data_from_tree_sequence(ts):
    """
    Build a tsinfer.SampleData object directly from a tree sequence,
    without going through a VCF file.

    Each site of `ts` is added with its position, genotypes and alleles.
    Sites with more than two alleles are left out.

    :param ts: Tree sequence whose variants are copied.
    :return: The resulting tsinfer.SampleData, with the same sequence length as `ts`.
    """
    with tsinfer.SampleData(ts.sequence_length) as sd:
        for var in ts.variants():
            is_multiallelic = len(var.alleles) > 2
            if not is_multiallelic:
                sd.add_site(
                    position=var.site.position,
                    genotypes=var.genotypes,
                    alleles=var.alleles,
                )
    return sd

In [142]:
# Sample data for every simulated genome; multi-allelic sites are dropped
# by the helper before anything else happens.
sd_full = create_sample_data_from_tree_sequence(ts=ts_full)

# Restrict to the query individuals: this is the input used for matching.
sd_query = sd_full.subset(individuals=individuals_query)

In [143]:
# Reduce the full tree sequence to the reference samples (removes private
# branches); filter_sites=False is passed so the site table is not filtered.
ts_ref = ts_full.simplify(samples=samples_ref, filter_sites=False)

# Turn the reference tree sequence into an ancestors tree sequence.
ts_anc = make_ancestors_ts(samples=None, ts=ts_ref, remove_leaves=True)

# Drop the individuals and reset the population metadata schema before
# matching -- presumably required by match_samples; TODO confirm.
tmp_tables = ts_anc.dump_tables()
tmp_tables.populations.metadata_schema = tskit.MetadataSchema(None)
tmp_tables.individuals.clear()

# Ancestors tree sequence used for matching.
ts_anc = tmp_tables.tree_sequence()

In [144]:
# Match the unmasked query genomes against the ancestors tree sequence.
ts_matched = tsinfer.match_samples(
    sample_data=sd_query,
    ancestors_ts=ts_anc,
)
ts_matched

Tree Sequence,Unnamed: 1
Trees,83
Sequence Length,1000.0
Time Units,generations
Sample Nodes,10
Total Size,36.2 KiB
Metadata,dict

Table,Rows,Size,Has Metadata
Edges,367,11.5 KiB,
Individuals,10,324 Bytes,✅
Migrations,0,8 Bytes,
Mutations,90,3.3 KiB,
Nodes,138,3.8 KiB,
Populations,0,8 Bytes,
Provenances,8,5.2 KiB,
Sites,181,9.1 KiB,✅


### Impute into query genomes from the ancestor tree sequence

In [148]:
def mask_sites_in_sample_data(sd, sequence_length, sites_to_mask=None):
    """
    Copy the sites of a SampleData object into a new one, replacing the
    genotypes at selected sites with missing data.

    Only site positions, alleles and genotypes are passed on to the new
    object; nothing else from `sd` is copied by this function.

    :param sd: tsinfer.SampleData whose variants are copied.
    :param sequence_length: Sequence length given to the new SampleData.
    :param sites_to_mask: Iterable of site indexes, counted in the order
        yielded by ``sd.variants()``. At these sites every genotype is set to
        ``tskit.MISSING_DATA``; the alleles are kept. ``None`` (the default)
        masks nothing.
    :return: The new tsinfer.SampleData.
    """
    # Membership is tested once per site, so use a set: looking up each site
    # in a list would make the whole loop quadratic in the number of sites.
    mask = frozenset(sites_to_mask) if sites_to_mask is not None else frozenset()

    with tsinfer.SampleData(sequence_length) as sample_data:
        for i, variant in enumerate(sd.variants()):
            if i in mask:
                genotypes = np.repeat(tskit.MISSING_DATA, len(variant.genotypes))
            else:
                genotypes = variant.genotypes
            sample_data.add_site(
                position=variant.site.position,
                genotypes=genotypes,
                alleles=variant.alleles,  # Unchanged, even at masked sites
            )
    return sample_data

In [154]:
# Pick the sites to mask in the query. Site indexes refer to `sd_query`,
# the object that is masked below, and `proportion_missing_sites` sets the
# fraction of sites whose genotypes are replaced by missing data.
# NOTE: `random` is not seeded here, so the mask differs between runs.
masked_sites = random.sample(
    range(sd_query.num_sites),
    round(proportion_missing_sites * sd_query.num_sites),
)

sd_query_masked = mask_sites_in_sample_data(
    sd_query,
    sequence_length=sequence_length,
    sites_to_mask=masked_sites,
)

In [157]:
# Match the masked query genomes against the same ancestors tree sequence.
ts_masked_matched = tsinfer.match_samples(
    sample_data=sd_query_masked,
    ancestors_ts=ts_anc,
)
ts_masked_matched

Tree Sequence,Unnamed: 1
Trees,66
Sequence Length,1000.0
Time Units,generations
Sample Nodes,10
Total Size,24.7 KiB
Metadata,dict

Table,Rows,Size,Has Metadata
Edges,169,5.3 KiB,
Individuals,10,324 Bytes,✅
Migrations,0,8 Bytes,
Mutations,36,1.3 KiB,
Nodes,75,2.1 KiB,
Populations,0,8 Bytes,
Provenances,8,5.2 KiB,
Sites,181,9.0 KiB,✅


### Compare the true and imputed tree sequences

In [161]:
# Summary of the tree sequence matched from the unmasked query genotypes.
ts_matched

Tree Sequence,Unnamed: 1
Trees,83
Sequence Length,1000.0
Time Units,generations
Sample Nodes,10
Total Size,36.2 KiB
Metadata,dict

Table,Rows,Size,Has Metadata
Edges,367,11.5 KiB,
Individuals,10,324 Bytes,✅
Migrations,0,8 Bytes,
Mutations,90,3.3 KiB,
Nodes,138,3.8 KiB,
Populations,0,8 Bytes,
Provenances,8,5.2 KiB,
Sites,181,9.1 KiB,✅


In [162]:
# Summary of the tree sequence matched from the masked query genotypes.
ts_masked_matched

Tree Sequence,Unnamed: 1
Trees,66
Sequence Length,1000.0
Time Units,generations
Sample Nodes,10
Total Size,24.7 KiB
Metadata,dict

Table,Rows,Size,Has Metadata
Edges,169,5.3 KiB,
Individuals,10,324 Bytes,✅
Migrations,0,8 Bytes,
Mutations,36,1.3 KiB,
Nodes,75,2.1 KiB,
Populations,0,8 Bytes,
Provenances,8,5.2 KiB,
Sites,181,9.0 KiB,✅


### Write results to VCF

In [158]:
# Output paths for the two VCF files written below: one for the match on
# masked genotypes ("imputed"), one for the match on unmasked genotypes ("true").
imputed_vcf_file = "test.imputed.vcf.gz"
true_vcf_file = "test.true.vcf.gz"

In [159]:
# Write the match on unmasked genotypes as gzip-compressed VCF text.
with gzip.open(true_vcf_file, mode="wt") as vcf_out:
    ts_matched.write_vcf(vcf_out)

In [160]:
# Write the match on masked genotypes as gzip-compressed VCF text.
with gzip.open(imputed_vcf_file, mode="wt") as vcf_out:
    ts_masked_matched.write_vcf(vcf_out)