In [1]:
import os
import gzip

import numpy as np

import msprime
import tskit
import tsinfer
from tsinfer import make_ancestors_ts

print(f"msprime {msprime.__version__}")
print(f"tskit {tskit.__version__}")
print(f"tsinfer {tsinfer.__version__}")

msprime 1.1.1
tskit 0.5.2.dev0
tsinfer 0.2.3


### Helper functions

In [2]:
def count_sites_by_type(ts_or_sd):
    """
    Iterate through the variants of a TreeSequence or SampleData object,
    and count the number of mono-, bi-, tri-, and quad-allelic sites.
    
    :param TreeSequence/SampleData ts_or_sd:
    """
    assert isinstance(ts_or_sd, tskit.TreeSequence) or\
        isinstance(ts_or_sd, tsinfer.SampleData)
    
    sites_mono = 0
    sites_bi = 0
    sites_bi_singleton = 0
    sites_tri = 0
    sites_quad = 0
    
    for v in ts_or_sd.variants():
        num_alleles = len(set(v.alleles) - {None})
        if num_alleles == 1:
            sites_mono += 1
        elif num_alleles == 2:
            sites_bi += 1
            if np.sum(v.genotypes) == 1:
                sites_bi_singleton += 1
        elif num_alleles == 3:
            sites_tri += 1
        else:
            sites_quad += 1
    
    sites_total = sites_mono + sites_bi + sites_tri + sites_quad
    
    print(f"\tsites mono : {sites_mono}")
    print(f"\tsites bi   : {sites_bi} ({sites_bi_singleton} singletons)")
    print(f"\tsites tri  : {sites_tri}")
    print(f"\tsites quad : {sites_quad}")
    print(f"\tsites total: {sites_total}")

In [3]:
def check_site_positions_ts_issubset_sd(tree_sequence, sample_data):
    """
    Check whether the site positions in `TreeSequence` are a subset of
    the site positions in `SampleData`.
    
    :param TreeSequence tree_sequence:
    :param SampleData sample_data:
    :return bool:
    """
    ts_site_positions = np.empty(tree_sequence.num_sites)
    sd_site_positions = np.empty(sample_data.num_sites)
    
    i = 0
    for v in tree_sequence.variants():
        ts_site_positions[i] = v.site.position
        i += 1
        
    j = 0
    for v in sample_data.variants():
        sd_site_positions[j] = v.site.position
        j += 1
        
    assert i == tree_sequence.num_sites
    assert j == sample_data.num_sites
    
    if set(ts_site_positions).issubset(set(sd_site_positions)):
        return True
    else:
        return False

In [4]:
def compare_sites_sd_and_ts(
    sample_data,
    tree_sequence,
    is_common,
    check_matching_ancestral_state=True
):
    """
    If `include` is set to True, then get the ids of the sites
    found in `sample_data` AND in `tree_sequence`.
    
    if `exclude` is set to False, then get the ids of the sites
    found in `sample_data` but NOT in `tree_sequence`.
    
    :param TreeSequence tree_sequence:
    :param SampleData sample_data:
    :param include bool:
    :return np.array:
    """
    ts_site_positions = np.empty(tree_sequence.num_sites)
    
    i = 0
    for v in tree_sequence.variants():
        ts_site_positions[i] = v.site.position
        i += 1
        
    assert i == tree_sequence.num_sites
    
    sd_site_id = []
    for v in sample_data.variants():
        if is_common:
            if v.site.position in ts_site_positions:
                sd_site_id.append(v.site.id)
                if check_matching_ancestral_state:
                    ts_site = tree_sequence.site(position=v.site.position)
                    assert v.site.ancestral_state == ts_site.ancestral_state,\
                        f"Ancestral states at position {v.site.position} not identical, " +\
                        f"{v.site.ancestral_state} vs. {ts_site.ancestral_state}."
        else:
            if v.site.position not in ts_site_positions:
                sd_site_id.append(v.site.id)
    
    return(np.array(sd_site_id))

In [5]:
def make_compatible_sample_data(sample_data, ancestors_ts):
    """
    Make an editable copy of a `sample_data` object, and edit it so that:
    (1) the derived alleles in the `sample_data` object not found in `ancestors_ts` are marked as MISSING;
    (2) the allele list in the new `sample_data` corresponds to the allele list in `ancestors_ts`.
    
    N.B. Two `SampleData` attributes `sites_alleles` and `sites_genotypes`,
    which are not explained in the tsinfer API doc, are used to facilitate the editing.
    
    :param SampleData sample_data:
    :param TreeSequence ancestors_ts:
    :return SampleData:
    """
    new_sample_data = sample_data.copy()
    
    # Iterate through the sites in `ancestors_ts` using one generator,
    # while iterating through the sites in `sample_data` using another generator,
    # letting the latter generator catch up.
    sd_variants = sample_data.variants()
    sd_var = next(sd_variants)
    for ts_site in ancestors_ts.sites():
        while sd_var.site.position != ts_site.position:
            # Sites in `samples_data` but not in `ancestors_ts` are not imputed.
            # Also, leave them as is in the `sample_data`, but keep track of them.
            sd_var = next(sd_variants)
            
        sd_site_id = sd_var.site.id # Site id in `sample_data`
        
        # CHECK that all the sites in `ancestors_ts` are biallelic.
        assert len(ts_site.alleles) == 2
        
        # Get the derived allele in `ancestors_ts` in nucleotide space
        ts_ancestral_allele = ts_site.ancestral_state
        ts_derived_allele = ts_site.alleles - {ts_ancestral_allele}
        assert len(ts_derived_allele) == 1 # CHECK
        ts_derived_allele = tuple(ts_derived_allele)[0]
        
        # CHECK that the ancestral allele should be the same
        # in both `ancestors_ts` and `sample_data`.
        assert ts_ancestral_allele == sd_var.alleles[0]
        
        if ts_derived_allele not in sd_var.alleles:
            # Case 1:
            # If the derived alleles in the `sample_data` are not in `ancestors_ts`,
            # then mark them as missing.
            #
            # The site in `sample_data` may be mono-, bi-, or multiallelic.
            #
            # We cannot determine whether the extra derived alleles in `sample_data`
            # are derived from 0 or 1 in `ancestors_ts` anyway.
            new_sample_data.sites_genotypes[sd_site_id] = np.where(
                sd_var.genotypes != 0, # Keep if ancestral
                tskit.MISSING_DATA, # Otherwise, flag as missing
                0,
            )
            print(f"Site {sd_site_id} has no matching derived alleles in the query samples.")
            # Update allele list
            new_sample_data.sites_alleles[sd_site_id] = [ts_ancestral_allele]
        else:
            # The allele lists in `ancestors_ts` and `sample_data` may be different.
            ts_derived_allele_index = sd_var.alleles.index(ts_derived_allele)
            
            if ts_derived_allele_index == 1:
                # Case 2:
                # Both the ancestral and derived alleles correspond exactly.
                if len(sd_var.alleles) == 2:
                    continue
                # Case 3:
                # The derived allele in `ancestors_ts` is indexed as 1 in `sample_data`,
                # so mark alleles >= 2 as missing.
                new_sample_data.sites_genotypes[sd_site_id] = np.where(
                    sd_var.genotypes > 1, # 0 and 1 should be kept "as is"
                    tskit.MISSING_DATA, # Otherwise, flag as missing
                    sd_var.genotypes,
                )
                print(f"Site {sd_site_id} has extra derived allele(s) in the query samples (set as missing).")
            else:
                # Case 4:
                #   The derived allele in `ancestors_ts` is NOT indexed as 1 in `sample_data`,
                #   so the alleles in `sample_data` needs to be reordered,
                #   such that the 1-indexed allele is also indexed as 1 in `ancestors_ts`.
                new_sample_data.sites_genotypes[sd_site_id] = np.where(
                    sd_var.genotypes == 0,
                    0, # Leave ancestral allele "as is"
                    np.where(
                        sd_var.genotypes == ts_derived_allele_index,
                        1, # Change it to 1 so that it corresponds to `ancestors_ts`
                        tskit.MISSING_DATA, # Otherwise, mark as missing
                    ),
                )
                print(f"Site {sd_site_id} has the target derived allele at a different index.")
            # Update allele list
            new_sample_data.sites_alleles[sd_site_id] = [ts_ancestral_allele, ts_derived_allele]
            
    new_sample_data.finalise()
    
    return(new_sample_data)

In [32]:
def pick_masked_sites_random(site_ids, prop_masked_sites):
    """
    Draw N sites from `sites_ids` at random, where N is the number of sites to mask
    based on a specified proportion of masked sites `prop_masked_sites`.
    
    TODO: Specify random seed.
    
    :param np.array site_ids:
    :param float prop_masked_sites: float between 0 and 1
    :return np.array: list of site ids
    """
    assert prop_masked_sites >= 0
    assert prop_masked_sites <= 1
    
    rng = np.random.default_rng()
    
    num_masked_sites = int(np.floor(len(site_ids) * prop_masked_sites))
    
    masked_site_ids = np.sort(
        rng.choice(
            site_ids,
            num_masked_sites,
            replace=False,
        )
    )
    
    return(masked_site_ids)

In [7]:
def mask_sites_in_sample_data(sample_data, masked_sites=None):
    """
    Create and return a `SampleData` object from an existing `SampleData` object,
    which contain masked sites as listed in `masked_sites` (site ids).
    
    :param SampleData sample_data:
    :param np.array masked_sites: list of site ids (NOT positions)
    :return SampleData:
    """
    new_sample_data = sample_data.copy()
    
    for v in sample_data.variants():
        if v.site.id in masked_sites:
            new_sample_data.sites_genotypes[v.site.id] = np.full_like(v.genotypes, tskit.MISSING_DATA)
    
    new_sample_data.finalise()
    
    return(new_sample_data)

### Simulation parameters

In [40]:
#seed = 123

replicate_index = 0

sampling_time_query = 0

proportion_missing_sites = 0.80

In [9]:
# For simulations
# size_ref   = 1e4
# size_query = 1e3

# eff_pop_size = 10_000
# mutation_rate = 1e-8
# recombination_rate = 1e-8

# contig_id = '1'
# ploidy_level = 1
# sequence_length = 1_000_000

In [10]:
# For testing
size_ref   = 50
size_query = 50

eff_pop_size = 10_000
mutation_rate = 1e-6
recombination_rate = 1e-7

contig_id = '1'
ploidy_level = 1
sequence_length = 10_000

### Simulate genealogy and genetic variation

In [11]:
# Uniform recombination rate
recomb_rate_map = msprime.RateMap.uniform(
    sequence_length=sequence_length,
    rate=recombination_rate,
)

# Uniform mutation rate
mut_rate_map = msprime.RateMap.uniform(
    sequence_length=sequence_length,
    rate=mutation_rate,
)

In [12]:
sample_set = [
    # Reference genomes
    msprime.SampleSet(num_samples=size_ref,
                      time=0,
                      ploidy=ploidy_level),
    # Query genomes
    msprime.SampleSet(num_samples=size_query,
                      time=sampling_time_query,
                      ploidy=ploidy_level),
]

In [14]:
# A simulated tree sequence does not contain any monoallelic sites,
# but there may be multiallelic sites.
ts_full = msprime.sim_mutations(
    msprime.sim_ancestry(
        samples=sample_set,
        population_size=eff_pop_size,
        model="hudson",
        recombination_rate=recomb_rate_map,
        discrete_genome=True,
        #random_seed=seed,
    ),
    rate=mut_rate_map,
    discrete_genome=True,
    #random_seed=seed,
)

# Remove populations
tables = ts_full.dump_tables()
tables.populations.clear()
tables.nodes.population = np.full_like(tables.nodes.population, tskit.NULL)
ts_full = tables.tree_sequence()

print("TS full")
count_sites_by_type(ts_full)

TS full
	sites mono : 0
	sites bi   : 1404 (428 singletons)
	sites tri  : 84
	sites quad : 1
	sites total: 1489


In [15]:
# The first `size_ref` individuals or `ploidy_level` * `size_ref` samples are taken as the reference panel.
# The remaining individuals and samples are the query/target to impute into.
individuals_ref = np.arange(size_ref, dtype=int)
samples_ref = np.arange(ploidy_level * size_ref, dtype=int)

individuals_query = np.arange(size_ref, size_ref + size_query, dtype=int)
samples_query = np.arange(ploidy_level * size_ref, ploidy_level * (size_ref + size_query), dtype=int)

### Create an ancestor ts from the reference genomes

In [16]:
# Remove all the branches leading to the query genomes
ts_ref = ts_full.simplify(samples_ref, filter_sites=False)

print(f"TS ref has {ts_ref.num_samples} sample genomes ({ts_ref.sequence_length} bp)")
print(f"TS ref has {ts_ref.num_sites} sites and {ts_ref.num_trees} trees")
print("TS ref")
count_sites_by_type(ts_ref)

TS ref has 50 sample genomes (10000.0 bp)
TS ref has 1489 sites and 87 trees
TS ref
	sites mono : 304
	sites bi   : 1132 (396 singletons)
	sites tri  : 52
	sites quad : 1
	sites total: 1489


In [17]:
# Multiallelic sites are automatically removed when generating an ancestor ts.
# Sites which are biallelic in the full sample set but monoallelic in the ref. sample set are removed.
# So, only biallelic sites are retained in the ancestor ts.
ts_anc = make_ancestors_ts(ts=ts_ref, remove_leaves=True, samples=None)

print(f"TS anc has {ts_anc.num_samples} sample genomes ({ts_anc.sequence_length} bp)")
print(f"TS anc has {ts_anc.num_sites} sites and {ts_anc.num_trees} trees")
print("TS anc")
count_sites_by_type(ts_anc)

TS anc has 121 sample genomes (10000.0 bp)
TS anc has 714 sites and 76 trees
TS anc
	sites mono : 0
	sites bi   : 714 (0 singletons)
	sites tri  : 0
	sites quad : 0
	sites total: 714


### Create a SampleData object holding the query genomes 

In [18]:
sd_full = tsinfer.SampleData.from_tree_sequence(ts_full)
sd_query = sd_full.subset(individuals_query)

print(f"SD query has {sd_query.num_samples} sample genomes ({sd_query.sequence_length} bp)")
print(f"SD query has {sd_query.num_sites} sites")
print("SD query")
count_sites_by_type(sd_query)

SD query has 50 sample genomes (10000.0 bp)
SD query has 1489 sites
SD query
	sites mono : 0
	sites bi   : 1404 (487 singletons)
	sites tri  : 84
	sites quad : 1
	sites total: 1489


In [19]:
assert check_site_positions_ts_issubset_sd(ts_anc, sd_query)

In [20]:
new_sd_query = make_compatible_sample_data(
    sample_data=sd_query,
    ancestors_ts=ts_anc,
)

Site 49 has extra derived allele(s) in the query samples (set as missing).
Site 56 has extra derived allele(s) in the query samples (set as missing).
Site 87 has extra derived allele(s) in the query samples (set as missing).
Site 445 has extra derived allele(s) in the query samples (set as missing).
Site 676 has extra derived allele(s) in the query samples (set as missing).
Site 791 has extra derived allele(s) in the query samples (set as missing).
Site 880 has the target derived allele at a different index.
Site 894 has extra derived allele(s) in the query samples (set as missing).
Site 896 has extra derived allele(s) in the query samples (set as missing).
Site 941 has extra derived allele(s) in the query samples (set as missing).
Site 1034 has the target derived allele at a different index.
Site 1044 has the target derived allele at a different index.
Site 1161 has extra derived allele(s) in the query samples (set as missing).
Site 1219 has extra derived allele(s) in the query sample



In [21]:
# # Check that matching works
# ts_matched = tsinfer.match_samples(
#     sample_data=new_sd_query,
#     ancestors_ts=ts_anc,
# )

### Create a SampleData object with masked sites

In [31]:
# Identify sites in both `sd_query` and `ts_anc`.
shared_sites = compare_sites_sd_and_ts(sd_query, ts_anc, is_common=True)
print(f"Shared sites: {len(shared_sites)}")

# Identify sites in `sd_query` but not in `ts_anc`, which are not to be imputed.
exclude_sites = compare_sites_sd_and_ts(sd_query, ts_anc, is_common=False)
print(f"Exclude sites: {len(exclude_sites)}")

assert len(set(shared_sites).intersection(set(exclude_sites))) == 0

Shared sites: 714
Exclude sites: 775


In [44]:
masked_site_ids = pick_masked_sites_random(
    site_ids=shared_sites,
    prop_masked_sites=proportion_missing_sites,
)

In [45]:
sd_query_masked = mask_sites_in_sample_data(
    new_sd_query,
    masked_sites=masked_site_ids,
)



### Impute the query genomes

In [68]:
ts_imp = tsinfer.match_samples(
    sample_data=sd_query_masked,
    ancestors_ts=ts_anc,
)

### Evaluate imputation performance

In [90]:
assert ts_ref.num_sites == ts_imp.num_sites

ts_ref_site_pos = [site.position for site in ts_ref.sites()]
ts_imp_site_pos = [site.position for site in ts_imp.sites()]

assert len(set(ts_ref_site_pos).intersection(set(ts_imp_site_pos))) == len(ts_ref_site_pos)

for v_ref, v_imp in zip(ts_ref.variants(), ts_imp.variants()):
    if v_imp.site.id in shared_sites:
        assert v_ref.alleles[0] == v_imp.alleles[0]
        if len(v_imp.alleles) == 1:
            # Monoallelic sites in `sd_query` cannot be imputed
            continue
        assert v_ref.num_alleles == 2
        assert v_imp.num_alleles == 2
        freqs_ref = v_ref.frequencies()
        freqs_imp = v_imp.frequencies()
        # Note: A minor allele in `ts_ref` may be a major allele in `sd_query`
        af_0 = freqs_ref[v_ref.alleles[0]]
        af_1 = freqs_ref[v_ref.alleles[1]]
        # Get MAF from `ts_ref`
        if af_1 < af_0:
            minor_allele_index = 1
            maf = af_1
        assert not np.any(v_imp.genotypes == -1)
        total_concordance = np.sum(v_ref.genotypes == v_imp.genotypes) / len(v_ref.genotypes)
        print(v_ref.site.position, maf, total_concordance)

5.0 0.12 0.84
7.0 0.12 0.84
29.0 0.1 0.84
30.0 0.1 0.84
40.0 0.12 0.84
44.0 0.12 0.84
52.0 0.1 0.76
57.0 0.12 0.84
61.0 0.12 0.84
75.0 0.12 0.84
76.0 0.12 0.84
80.0 0.12 0.84
89.0 0.14 0.78
100.0 0.12 0.84
105.0 0.36 0.5
118.0 0.36 0.66
119.0 0.16 0.74
121.0 0.06 0.84
140.0 0.12 0.84
142.0 0.12 0.66
145.0 0.04 0.94
146.0 0.14 0.78
149.0 0.04 0.88
158.0 0.12 0.86
161.0 0.04 0.92
165.0 0.36 0.54
172.0 0.1 0.84
177.0 0.1 0.84
188.0 0.04 0.86
190.0 0.12 0.86
210.0 0.1 0.88
213.0 0.04 0.94
218.0 0.36 0.58
229.0 0.14 0.78
235.0 0.36 0.54
237.0 0.36 0.68
239.0 0.12 0.82
252.0 0.12 0.86
254.0 0.12 0.86
255.0 0.12 0.68
265.0 0.28 0.68
270.0 0.12 0.86
282.0 0.04 0.92
288.0 0.12 0.82
290.0 0.18 0.72
298.0 0.04 0.92
299.0 0.14 0.78
302.0 0.14 0.68
313.0 0.04 0.92
315.0 0.36 0.54
317.0 0.04 0.94
321.0 0.36 0.54
327.0 0.04 0.94
330.0 0.12 0.82
345.0 0.24 0.68
358.0 0.24 0.68
364.0 0.04 0.94
370.0 0.36 0.54
389.0 0.12 0.82
391.0 0.12 0.82
413.0 0.04 0.86
435.0 0.1 0.88
473.0 0.1 0.76
488.0 0.08 0.86


### Write VCF

In [15]:
# ref_vcf_file = ".".join(["ref", str(replicate_index), "vcf", "gz"])

# query_true_vcf_file = ".".join(["query_true", str(replicate_index), "vcf", "gz"])
# query_miss_vcf_file = ".".join(["query_miss", str(replicate_index), "vcf", "gz"])

# imputed_vcf_file = ".".join(["imputed", str(replicate_index), "vcf", "gz"])

In [16]:
# with gzip.open(ref_vcf_file, "wt") as f:
#     ts_ref.write_vcf(f)

In [18]:
# TODO:
#   Write VCF from SD query
#   Write VCF from SD query masked

In [19]:
# with gzip.open(imputed_vcf_file, "wt") as f:
#     ts_imputed.write_vcf(f)