In [None]:
import numpy as np
import tskit
import tsinfer
import msprime


In [None]:
import sys
sys.path.append("../src")
import masks
import measures
import util
import simulate_ts


In [None]:
# Population-matched imputation: the reference panel and the query set are
# both drawn from the same population label ("CEU").
num_ref_inds, num_query_inds = 1_500, 500
ts_full = simulate_ts.get_ts_ten_pop(
    pop_ref="CEU",
    pop_query="CEU",
    num_ref_inds=num_ref_inds,
    num_query_inds=num_query_inds,
    sequence_length=1e7,  # 10 Mbp
)
ts_full


In [None]:
# Output file names for the tsimpute inputs; all share one run prefix.
prefix = "jacobs_ceu_ceu_2k"
ts_full_file = f"{prefix}.full.trees"
ts_ref_file = f"{prefix}.ref.trees"
ts_query_file = f"{prefix}.query.trees"
npy_query_file = f"{prefix}.query.npy"


In [None]:
# Index arrays for individuals and haplotypes. Reference entries come first
# and query entries follow immediately after them.
ploidy = 2
num_ref_haps, num_query_haps = ploidy * num_ref_inds, ploidy * num_query_inds

idx_ref_inds = np.arange(num_ref_inds)
idx_query_inds = num_ref_inds + np.arange(num_query_inds)
idx_ref_haps = np.arange(num_ref_haps)
idx_query_haps = num_ref_haps + np.arange(num_query_haps)

# Sanity checks on the node table layout: the first
# (num_ref_haps + num_query_haps) nodes must carry flag value 1
# (tskit.NODE_IS_SAMPLE) and every later node must carry flag value 0.
assert (ts_full.nodes_flags[: num_ref_haps + num_query_haps] == 1).all()
assert not (ts_full.nodes_flags[num_ref_haps + num_query_haps :] != 0).any()
assert (ts_full.nodes_flags[idx_ref_haps] == 1).all()
assert (ts_full.nodes_flags[idx_query_haps] == 1).all()


In [None]:
# Keep only the reference haplotypes as samples; filter_sites=True lets
# simplify drop the sites that become monoallelic in this subset.
ts_ref = ts_full.simplify(samples=idx_ref_haps, filter_sites=True)
ts_ref


In [None]:
# Identify and remove sites with private mutations: for each site take the
# smallest per-allele count among the reference haplotypes, and drop the site
# when that count is below 2.
af = np.fromiter(
    (min(variant.counts().values()) for variant in ts_ref.variants()),
    dtype=np.int32,
    count=ts_ref.num_sites,
)
sites_private_mutation = np.flatnonzero(af < 2)
print(f"Sites with private mutation: {len(sites_private_mutation)}")
ts_ref_filtered = ts_ref.delete_sites(site_ids=sites_private_mutation)
ts_ref_filtered


In [None]:
# Identify sites with high MAF: the smallest per-allele frequency at each
# site of the filtered reference panel must be at least 0.05.
maf = np.fromiter(
    (min(variant.frequencies().values()) for variant in ts_ref_filtered.variants()),
    dtype=np.float64,
    count=ts_ref_filtered.num_sites,
)
sites_high_maf = np.flatnonzero(maf >= 0.05)
print(f"Sites with high MAF: {len(sites_high_maf)}")


In [None]:
# Randomly select genotyped markers among the high-MAF reference sites.
# A dedicated, seeded Generator makes the selection (and every file derived
# from it) identical across Restart & Run All; change MARKER_SEED to draw a
# different marker set.
MARKER_SEED = 42
rng = np.random.default_rng(MARKER_SEED)

reference_markers = np.arange(ts_ref_filtered.num_sites)
num_markers = 3333 # Density of 3,333 markers per 10 Mb

# Fail early with a clear message rather than the generic ValueError raised
# when sampling without replacement from too few candidates.
assert len(sites_high_maf) >= num_markers, (
    f"Only {len(sites_high_maf)} high-MAF sites available; "
    f"cannot select {num_markers} genotyped markers."
)

# sites_high_maf holds site indices into ts_ref_filtered, so the selected
# markers index the same site table as reference_markers.
genotyped_markers = np.sort(
    rng.choice(sites_high_maf, size=num_markers, replace=False)
)
# Every reference site that was not selected is treated as ungenotyped.
ungenotyped_markers = np.setdiff1d(reference_markers, genotyped_markers)
assert np.union1d(genotyped_markers,
                  ungenotyped_markers).size == ts_ref_filtered.num_sites


In [None]:
# Site positions of the genotyped and ungenotyped markers, looked up in the
# filtered reference panel by site index.
genotyped_site_pos, ungenotyped_site_pos = (
    ts_ref_filtered.sites_position[marker_idx]
    for marker_idx in (genotyped_markers, ungenotyped_markers)
)


In [None]:
# Summary of how the reference sites were split into marker sets.
print(
    f"Reference markers: {ts_ref_filtered.num_sites}",
    f"Genotyped markers: {genotyped_markers.size}",
    f"Ungenotyped markers: {ungenotyped_markers.size}",
    sep="\n",
)


In [None]:
# Prepare query haplotypes: keep only the query haplotypes as samples, and
# retain every site (filter_sites=False) so sites can be matched to the
# reference panel by position afterwards.
# WARN: Extracting query haplotypes like this only works when using ACGT encoding.
ts_query = ts_full.simplify(samples=idx_query_haps, filter_sites=False)
ts_query


In [None]:
# Filter sites in query haplotypes down to reference markers: delete every
# query site whose position does not occur in the filtered reference panel.
remove_sites = np.flatnonzero(
    ~np.isin(ts_query.sites_position, ts_ref_filtered.sites_position)
)
ts_query_filtered = ts_query.delete_sites(site_ids=remove_sites)

# Both tree sequences must now describe exactly the same sites, in order.
assert ts_query_filtered.num_sites == ts_ref_filtered.num_sites
assert np.array_equal(
    ts_query_filtered.sites_position,
    ts_ref_filtered.sites_position,
)
ts_query_filtered


In [None]:
# Unmasked query haplotypes: genotype matrix of the filtered query tree
# sequence, with alleles encoded using the ACGT allele mapping.
ts_query_h = ts_query_filtered.genotype_matrix(
    alleles=tskit.ALLELES_ACGT,
)
print(np.shape(ts_query_h))
ts_query_h


In [None]:
# Masked query haplotypes: a copy of the unmasked matrix in which every row
# (site) belonging to an ungenotyped marker is overwritten with the
# missing-data code, for all query haplotypes.
ts_query_h_masked = ts_query_h.copy()
ts_query_h_masked[ungenotyped_markers, :] = tskit.MISSING_DATA  # == -1
assert ts_query_h.shape == ts_query_h_masked.shape
# Display last: a bare expression only renders when it ends the cell.
ts_query_h_masked


In [None]:
# Save the query data for tsimpute as six arrays written back-to-back into a
# single file. Read them back with six successive np.load calls on one open
# file handle, in the same order:
#   1. unmasked query haplotypes
#   2. masked query haplotypes
#   3. site indices of genotyped markers
#   4. site indices of ungenotyped markers
#   5. site positions of genotyped markers
#   6. site positions of ungenotyped markers
# The index arrays are genotyped_markers / ungenotyped_markers; the names
# genotyped_site_idx / ungenotyped_site_idx used previously are not defined
# anywhere in this notebook and raised NameError on a fresh kernel.
with open(npy_query_file, "wb") as f:
    np.save(f, ts_query_h)
    np.save(f, ts_query_h_masked)
    np.save(f, genotyped_markers)
    np.save(f, ungenotyped_markers)
    np.save(f, genotyped_site_pos)
    np.save(f, ungenotyped_site_pos)


In [None]:
# Write the full simulation and the filtered reference panel to disk.
# NOTE(review): ts_query_file is defined alongside these paths but nothing is
# ever written to it in this notebook -- confirm whether the query tree
# sequence should be dumped as well.
for tree_seq, out_path in (
    (ts_full, ts_full_file),
    (ts_ref_filtered, ts_ref_file),
):
    tree_seq.dump(out_path)


In [None]:
# Prepare files for BEAGLE 4.1
import gzip

# Reference panel: every site of the filtered reference tree sequence.
with gzip.open(f"{prefix}.ref.vcf.gz", mode="wt") as vcf_out:
    ts_ref_filtered.write_vcf(vcf_out)

# Query panel: mask out (True = omit from the VCF) the ungenotyped markers so
# that only the genotyped markers are written.
site_mask = np.isin(np.arange(ts_ref_filtered.num_sites), ungenotyped_markers)
assert site_mask.sum() == ungenotyped_markers.size
with gzip.open(f"{prefix}.query.vcf.gz", mode="wt") as vcf_out:
    ts_query_filtered.write_vcf(vcf_out, site_mask=site_mask)
