In [1]:
# Mount Google Drive at /content/drive so the data and output folders used
# below are reachable from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install toolz scikit-allel

Collecting scikit-allel
  Downloading scikit_allel-1.3.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.4/7.4 MB[0m [31m20.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: scikit-allel
Successfully installed scikit-allel-1.3.6


## Setup

In [3]:
import os
# os.environ["MODIN_CPUS"] = "8"
# from distributed import Client
# client = Client()
import glob
import numpy as np
import math
import re
import random
import shutil
import gzip
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.utils import resample
from sklearn import metrics
from sklearn.metrics import confusion_matrix
from sklearn.metrics import f1_score
# import allel
from scipy.spatial.distance import squareform
from scipy.special import softmax
%matplotlib inline
from toolz import interleave
from tqdm import tqdm
from matplotlib import pyplot as plt
import tensorflow_datasets as tfds
from sklearn.metrics import mean_squared_error
from sklearn.linear_model import LassoCV, ElasticNetCV
from sklearn.model_selection import KFold,StratifiedKFold

## Prepare the data

In [4]:
class DataLoader:
    """
    Load a reference panel VCF and a target VCF and convert between genotype
    strings and the integer encodings used by the model.

    If the reference is unphased, cannot handle phased target data, so the valid (ref, target) combinations are:
    (phased, phased), (phased, unphased), (unphased, unphased)
    Important note: for each case, the model should be trained separately
    """

    # Column index of the first sample in a VCF data row (after the nine fixed
    # columns #CHROM..FORMAT).
    _FIRST_SAMPLE_COL = 9

    def __init__(self, reference_panel_file_path, target_file_path):
        """
        Read both VCF files (plain or ``.gz``) and build the genotype lookup tables.

        :param reference_panel_file_path: path to the reference panel VCF
        :param target_file_path: path to the target VCF; variants missing from it
            (relative to the reference) are filled with the missing genotype
        :raises ValueError: if the reference is unphased but the target is phased
        """
        first = self._FIRST_SAMPLE_COL
        self.map_values_1_vec = np.vectorize(self.map_hap_2_ind_parent_1)
        self.map_values_2_vec = np.vectorize(self.map_hap_2_ind_parent_2)

        print("Rading the reference file...")
        self.ref_n_header_lines, self.ref_n_data_header = self._read_vcf_header(reference_panel_file_path)
        self.reference_panel = pd.read_csv(reference_panel_file_path,
                                           comment='#',
                                           sep='\t',
                                           names=self.ref_n_data_header.strip().split('\t'))
        self.VARIANT_COUNT = self.reference_panel.shape[0]
        print(f"{self.VARIANT_COUNT} variants found. Done!")

        print("Rading the target file...")
        self.target_n_header_lines, self.target_n_data_header = self._read_vcf_header(target_file_path)
        real_target_set = pd.read_csv(target_file_path,
                                      comment='#',
                                      sep='\t',
                                      names=self.target_n_data_header.strip().split('\t'))
        print(f"{real_target_set.shape[0]} variants found. Done!")

        # Phasing is detected from the first sample's genotype in the first row.
        # (Sample columns start at index 9, so this also works for one-sample files.)
        target_is_phased = "|" in real_target_set.iloc[0, first]
        ref_is_phased = "|" in self.reference_panel.iloc[0, first]
        self.is_phased = target_is_phased and ref_is_phased

        print("Creating the new target dataframe")
        # Right-merge on the reference IDs so the target has one row per reference
        # variant, in reference order; variants absent from the target become NaN.
        self.target_set = real_target_set.merge(self.reference_panel["ID"], on='ID', how='right')
        fixed_columns = self.reference_panel.columns[:first]
        self.target_set[fixed_columns] = self.reference_panel[fixed_columns]
        self.target_set = self.target_set.fillna(".|." if self.is_phased else "./.")

        print("Extracting genotype information...")
        SEP = "|" if self.is_phased else "/"

        def get_num_allels(g):
            v1, v2 = g.split(SEP)
            return max(int(v1), int(v2)) + 1

        def key_gen(v1, v2):
            return f"{v1}{SEP}{v2}"

        genotype_vals = np.unique(self.reference_panel.iloc[:, first:].values)
        if target_is_phased != ref_is_phased:
            if not ref_is_phased:
                raise ValueError("An unphased reference panel cannot be used with a phased target file.")
            # Phased reference + unphased target: collapse the reference to
            # unphased genotypes with the smaller allele index first.
            phased_to_unphased_dict = {}
            for i in range(genotype_vals.shape[0]):
                key = genotype_vals[i]
                v1, v2 = [int(s) for s in genotype_vals[i].split("|")]
                genotype_vals[i] = f"{min(v1, v2)}{SEP}{max(v1, v2)}"
                phased_to_unphased_dict[key] = genotype_vals[i]
            self.reference_panel.replace(phased_to_unphased_dict, inplace=True)
        genotype_vals = np.unique(genotype_vals)
        allele_count = max(map(get_num_allels, genotype_vals))
        if self.is_phased:
            # Allele string -> integer code; the missing allele "." gets the last code.
            self.hap_map = {str(i): i for i in range(allele_count)}
            self.hap_map.update({".": allele_count})
            self.r_hap_map = {i: k for k, i in self.hap_map.items()}
            self.map_preds_2_allele = np.vectorize(lambda x: self.r_hap_map[x])
        self.MISSING_VALUE = self.SEQ_DEPTH = allele_count + 1 if self.is_phased else len(genotype_vals) + 1
        self.genotype_keys = np.array([key_gen(i, j) for i in range(allele_count) for j in range(allele_count)]) if self.is_phased else genotype_vals
        self.genotype_keys = np.hstack([self.genotype_keys, [".|."] if self.is_phased else ["./."]])
        self.replacement_dict = {g: i for i, g in enumerate(self.genotype_keys)}
        self.reverse_replacement_dict = {i: g for g, i in self.replacement_dict.items()}

    @staticmethod
    def _read_vcf_header(file_path):
        """
        Read the leading header of a VCF file (gzip-compressed if it ends in ``.gz``).

        :return: ``(meta_lines, column_header_line)`` where ``meta_lines`` are the
            leading ``##`` lines and ``column_header_line`` is the first line
            that does not start with ``##`` (empty string for an empty file)
        """
        _, ext = os.path.splitext(file_path)
        meta_lines = []
        with gzip.open(file_path, 'rt') if ext == '.gz' else open(file_path, 'rt') as f_in:
            while True:
                line = f_in.readline()
                if line.startswith("##"):
                    meta_lines.append(line)
                else:
                    return meta_lines, line

    @staticmethod
    def _is_valid_window(starting_var_index, ending_var_index):
        """Return True when both indices are given and 0 <= start <= end."""
        return (starting_var_index is not None
                and ending_var_index is not None
                and 0 <= starting_var_index <= ending_var_index)

    def map_hap_2_ind_parent_1(self, x):
        """Integer code of the first haplotype's allele in a phased genotype string."""
        return self.hap_map[x.split('|')[0]]

    def map_hap_2_ind_parent_2(self, x):
        """Integer code of the second haplotype's allele in a phased genotype string."""
        return self.hap_map[x.split('|')[1]]

    def __get_forward_data(self, data: pd.DataFrame):
        """
        Encode a (variants x samples) frame of genotype strings as int32.

        Phased: returns (2 * n_samples, n_variants) with the two haplotypes of
        each sample on consecutive rows. Unphased: returns (n_samples, n_variants)
        of genotype codes from ``replacement_dict``.
        """
        if self.is_phased:
            # break it into haplotypes
            _x = np.empty((data.shape[1] * 2, data.shape[0]), dtype=np.int32)

            _x[0::2] = self.map_values_1_vec(data.values.T)
            _x[1::2] = self.map_values_2_vec(data.values.T)
            return _x
        else:
            return data.replace(self.replacement_dict).values.T.astype(np.int32)

    def get_ref_set(self, starting_var_index=None, ending_var_index=None):
        """
        Encoded reference panel for variants ``[starting_var_index, ending_var_index)``.

        If either index is omitted or the window is invalid, the whole sequence is used.
        """
        if self._is_valid_window(starting_var_index, ending_var_index):
            return self.__get_forward_data(
                self.reference_panel.iloc[starting_var_index:ending_var_index, self._FIRST_SAMPLE_COL:])
        print("No variant indices provided or indices not valid, using the whole sequence...")
        return self.__get_forward_data(self.reference_panel.iloc[:, self._FIRST_SAMPLE_COL:])

    def get_target_set(self, starting_var_index=None, ending_var_index=None):
        """
        Encoded target set for variants ``[starting_var_index, ending_var_index)``.

        If either index is omitted or the window is invalid, the whole sequence is used.
        """
        if self._is_valid_window(starting_var_index, ending_var_index):
            return self.__get_forward_data(
                self.target_set.iloc[starting_var_index:ending_var_index, self._FIRST_SAMPLE_COL:])
        print("No variant indices provided or indices not valid, using the whole sequence...")
        return self.__get_forward_data(self.target_set.iloc[:, self._FIRST_SAMPLE_COL:])

    def convert_haps_to_genotypes(self, allele_probs):
        '''
        Turn per-haplotype allele scores into a VCF-style DataFrame.

        :param allele_probs: array of (n_haploids, n_variants, n_alleles); a softmax
            is applied over the last axis. The genotype-probability collapse below
            assumes bi-allelic data.
        :return: copy of the target frame whose sample columns hold "GT:DS:GP" strings
        :raises ValueError: if the number of haploids is odd

        output format: GT:DS:GP
        '''
        FORMAT = "GT:DS:GP"
        n_haploids, n_variants, n_alleles = allele_probs.shape
        allele_probs_normalized = softmax(allele_probs, axis=-1)

        if n_haploids % 2 != 0:
            raise ValueError("Number of haploids should be even.")

        n_samples = n_haploids // 2
        genotypes = np.zeros((n_samples, n_variants), dtype=object)
        # Most likely allele per haplotype and variant, computed once up front.
        allele_calls = np.argmax(allele_probs_normalized, axis=-1)

        for i in tqdm(range(n_samples)):
            haploid_1 = allele_probs_normalized[2 * i]
            haploid_2 = allele_probs_normalized[2 * i + 1]

            for j in range(n_variants):
                # Flattened joint table: [P(0|0), P(0|1), P(1|0), P(1|1)] for two alleles.
                phased_probs = np.multiply.outer(haploid_1[j], haploid_2[j]).flatten()
                unphased_probs = np.array([phased_probs[0], sum(phased_probs[1:3]), phased_probs[-1]])
                unphased_probs_str = ",".join([f"{v:.6f}" for v in unphased_probs])
                alt_dosage = np.dot(unphased_probs, [0, 1, 2])
                variant_genotypes = [str(v) for v in allele_calls[2 * i:2 * i + 2, j]]
                genotypes[i, j] = '|'.join(variant_genotypes) + f":{alt_dosage:.3f}:{unphased_probs_str}"

        new_vcf = self.target_set.copy()
        new_vcf.iloc[:n_variants, self._FIRST_SAMPLE_COL:] = genotypes.T
        new_vcf["FORMAT"] = FORMAT
        new_vcf["QUAL"] = "."
        new_vcf["FILTER"] = "."
        new_vcf["INFO"] = "IMPUTED"
        return new_vcf

    def convert_unphased_probs_to_genotypes(self, allele_probs):
        '''
        Turn per-sample genotype scores into a VCF-style DataFrame.

        :param allele_probs: array of (n_samples, n_variants, n_genotypes); a softmax
            is applied over the last axis. The dosage computation assumes the three
            classes are 0/0, 0/1 and 1/1, in that order.
        :return: copy of the target frame whose sample columns hold "GT:DS:GP" strings

        output format: GT:DS:GP
        '''
        FORMAT = "GT:DS:GP"
        n_samples, n_variants, n_alleles = allele_probs.shape
        allele_probs_normalized = softmax(allele_probs, axis=-1)
        genotypes = np.zeros((n_samples, n_variants), dtype=object)

        for i in tqdm(range(n_samples)):
            for j in range(n_variants):
                unphased_probs = allele_probs_normalized[i, j]
                unphased_probs_str = ",".join([f"{v:.6f}" for v in unphased_probs])
                alt_dosage = np.dot(unphased_probs, [0, 1, 2])
                genotype_call = self.reverse_replacement_dict[int(np.argmax(unphased_probs))]
                # Field order must match FORMAT (GT, then DS, then GP).
                genotypes[i, j] = f"{genotype_call}:{alt_dosage:.3f}:{unphased_probs_str}"

        new_vcf = self.target_set.copy()
        new_vcf.iloc[:, self._FIRST_SAMPLE_COL:] = genotypes.T
        new_vcf["FORMAT"] = FORMAT
        new_vcf["QUAL"] = "."
        new_vcf["FILTER"] = "."
        new_vcf["INFO"] = "IMPUTED"
        return new_vcf

    def __get_headers_for_output(self):
        """Return the ``##`` meta-information lines written at the top of the output VCF."""
        headers = ["##fileformat=VCFv4.2",
                   '''##source=STI v1.0.0''',
                   '''##INFO=<ID=IMPUTED,Number=0,Type=Flag,Description="Marker was imputed">''',
                   '''##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">''',
                   '''##FORMAT=<ID=DS,Number=A,Type=Float,Description="Estimated Alternate Allele Dosage : [P(0/1)+2*P(1/1)]">''',
                   '''##FORMAT=<ID=GP,Number=G,Type=Float,Description="Estimated Posterior Probabilities for Genotypes 0/0, 0/1 and 1/1">''']
        return headers

    def preds_to_genotypes(self, preds):
        """
        WARNING: This only supports bi-allelic data right now!
        :param preds: numpy array of (n_samples, n_variants, n_alleles); for phased
            data the first axis holds haplotypes (two consecutive rows per sample)
        :return: pandas DataFrame laid out like the target VCF, with each sample
            cell holding a "GT:DS:GP" string, e.g., "0|1:0.901:0.29,0.52,0.19"
        """
        if self.is_phased:
            return self.convert_haps_to_genotypes(preds)
        else:
            return self.convert_unphased_probs_to_genotypes(preds)

    def write_ligated_results_to_vcf(self, df, file_name):
        """
        Write ``df`` as a tab-separated VCF to ``file_name`` (gzip if it ends in ".gz").

        The meta-information lines are written first; the table (including its
        column header row) is then appended by pandas.
        """
        with gzip.open(file_name, 'wt') if file_name.endswith(".gz") else open(file_name, 'wt') as f_out:
            # write info
            f_out.write("\n".join(self.__get_headers_for_output()) + "\n")
        df.to_csv(file_name, sep="\t", mode='a', index=False)

In [5]:
# Input locations: directory holding the VCFs, the reference-panel (training)
# file and the target (test) file that will be imputed.
root_data_dir = '[data_path]'
train_file_name = "beadchip_reference_all_minaf_05_snps_hwe_1e-2_filtered_train.vcf.gz"
test_file_name = "test_data_beadchip_hwe_filtered.vcf.gz"

In [6]:
# Build the paths with os.path.join so the result is correct whether or not
# root_data_dir ends with a path separator.
dl = DataLoader(os.path.join(root_data_dir, train_file_name),
                os.path.join(root_data_dir, test_file_name))

Rading the reference file...
31143 variants found. Done!
Rading the target file...
7336 variants found. Done!
Creating the new target dataframe
Extracting genotype information...


In [7]:
# Sanity check: distinct allele codes present in the encoded target set.
# Use the loader's own variant count instead of a hard-coded number.
np.unique(dl.get_target_set(0, dl.VARIANT_COUNT))

array([0, 1, 2], dtype=int32)

In [8]:
# True only when both the reference panel and the target file are phased.
dl.is_phased

True

In [9]:
# Allele string -> integer code lookup; the missing allele "." gets the last code.
dl.hap_map

{'0': 0, '1': 1, '.': 2}

## Ligate

In [10]:
# Ligation settings: assay label, file-name prefix of the per-window
# prediction arrays, the directory they live in, and the output VCF path.
assay_name = "beadchip"
files_prefix = f"{assay_name}_probs_window"
save_dir = "[save_dir]"
# os.path.join keeps the path valid whether or not save_dir ends with a separator.
output_path = os.path.join(save_dir, f"STI_{assay_name}_ligated_hmr.vcf.gz")

## Use this cell if your model is outputting several windows of variants

In [None]:
# Number of per-window prediction files produced by the model run.
N_WINDOWS = 16

# Build the paths from the configured directory and prefix, in numeric window
# order (a plain lexicographic sort would put window_10 before window_2).
chunk_paths = [os.path.join(save_dir, f"{files_prefix}_{window}.npy")
               for window in range(1, N_WINDOWS + 1)]

all_nparrays = []
for chunk_path in chunk_paths:
    print(chunk_path)
    all_nparrays.append(np.load(chunk_path))

# np.hstack joins along axis 1; for the 3-D arrays shown in the recorded shape
# that is the second (variant) axis.
preds = np.hstack(all_nparrays)
preds.shape

./drive/MyDrive/ShiLab/beadchip/STI/beadchip_probs_window_1.npy
./drive/MyDrive/ShiLab/beadchip/STI/beadchip_probs_window_2.npy
./drive/MyDrive/ShiLab/beadchip/STI/beadchip_probs_window_3.npy
./drive/MyDrive/ShiLab/beadchip/STI/beadchip_probs_window_4.npy
./drive/MyDrive/ShiLab/beadchip/STI/beadchip_probs_window_5.npy
./drive/MyDrive/ShiLab/beadchip/STI/beadchip_probs_window_6.npy
./drive/MyDrive/ShiLab/beadchip/STI/beadchip_probs_window_7.npy
./drive/MyDrive/ShiLab/beadchip/STI/beadchip_probs_window_8.npy
./drive/MyDrive/ShiLab/beadchip/STI/beadchip_probs_window_9.npy
./drive/MyDrive/ShiLab/beadchip/STI/beadchip_probs_window_10.npy
./drive/MyDrive/ShiLab/beadchip/STI/beadchip_probs_window_11.npy
./drive/MyDrive/ShiLab/beadchip/STI/beadchip_probs_window_12.npy
./drive/MyDrive/ShiLab/beadchip/STI/beadchip_probs_window_13.npy
./drive/MyDrive/ShiLab/beadchip/STI/beadchip_probs_window_14.npy
./drive/MyDrive/ShiLab/beadchip/STI/beadchip_probs_window_15.npy
./drive/MyDrive/ShiLab/beadchip/ST

(200, 31143, 2)

In [12]:
# Convert the stacked model outputs into a VCF-style DataFrame: one row per
# variant, one column per sample, each cell a "GT:DS:GP" string.
genotypes = dl.preds_to_genotypes(preds)
genotypes

100%|██████████| 100/100 [01:46<00:00,  1.06s/it]


Unnamed: 0,#CHROM,POS,ID,REF,ALT,QUAL,FILTER,INFO,FORMAT,NA20534,...,HG03985,NA18545,HG01342,HG02154,HG02232,NA19102,NA18526,HG02885,HG01139,HG03695
0,22,16052986,22:16052986:C:A,C,A,.,.,IMPUTED,GT:DS:GP,"0|0:0.590:0.497031,0.415947,0.087023",...,"0|0:0.590:0.497031,0.415947,0.087023","0|0:0.590:0.497031,0.415947,0.087023","0|0:0.590:0.497031,0.415947,0.087023","0|0:0.590:0.497031,0.415947,0.087023","0|0:0.590:0.497031,0.415947,0.087023","0|0:0.590:0.497031,0.415947,0.087023","0|0:0.590:0.497031,0.415947,0.087023","0|0:0.590:0.497031,0.415947,0.087023","0|0:0.590:0.497031,0.415947,0.087023","0|0:0.590:0.497031,0.415947,0.087023"
1,22,16053444,22:16053444:A:T,A,T,.,.,IMPUTED,GT:DS:GP,"0|0:0.584:0.501467,0.413353,0.085180",...,"0|0:0.584:0.501467,0.413353,0.085180","0|0:0.584:0.501467,0.413353,0.085180","0|0:0.584:0.501467,0.413353,0.085180","0|0:0.584:0.501467,0.413353,0.085180","0|0:0.584:0.501467,0.413353,0.085180","0|0:0.584:0.501467,0.413353,0.085180","0|0:0.584:0.501467,0.413353,0.085180","0|0:0.584:0.501467,0.413353,0.085180","0|0:0.584:0.501467,0.413353,0.085180","0|0:0.584:0.501467,0.413353,0.085180"
2,22,16053791,22:16053791:C:A,C,A,.,.,IMPUTED,GT:DS:GP,"0|0:0.682:0.434482,0.449341,0.116177",...,"0|0:0.682:0.434482,0.449341,0.116177","0|0:0.682:0.434482,0.449341,0.116177","0|0:0.682:0.434482,0.449341,0.116177","0|0:0.682:0.434482,0.449341,0.116177","0|0:0.682:0.434482,0.449341,0.116177","0|0:0.682:0.434482,0.449341,0.116177","0|0:0.682:0.434482,0.449341,0.116177","0|0:0.682:0.434482,0.449341,0.116177","0|0:0.682:0.434482,0.449341,0.116177","0|0:0.682:0.434482,0.449341,0.116177"
3,22,16055942,22:16055942:C:T,C,T,.,.,IMPUTED,GT:DS:GP,"1|1:1.318:0.116218,0.449379,0.434404",...,"1|1:1.318:0.116218,0.449379,0.434404","1|1:1.318:0.116218,0.449379,0.434404","1|1:1.318:0.116218,0.449379,0.434404","1|1:1.318:0.116218,0.449379,0.434404","1|1:1.318:0.116218,0.449379,0.434404","1|1:1.318:0.116218,0.449379,0.434404","1|1:1.318:0.116218,0.449379,0.434404","1|1:1.318:0.116218,0.449379,0.434404","1|1:1.318:0.116218,0.449379,0.434404","1|1:1.318:0.116218,0.449379,0.434404"
4,22,16058758,22:16058758:C:A,C,A,.,.,IMPUTED,GT:DS:GP,"0|0:0.681:0.434878,0.449150,0.115972",...,"0|0:0.681:0.434878,0.449150,0.115972","0|0:0.681:0.434878,0.449150,0.115972","0|0:0.681:0.434878,0.449150,0.115972","0|0:0.681:0.434878,0.449150,0.115972","0|0:0.681:0.434878,0.449150,0.115972","0|0:0.681:0.434878,0.449150,0.115972","0|0:0.681:0.434878,0.449150,0.115972","0|0:0.681:0.434878,0.449150,0.115972","0|0:0.681:0.434878,0.449150,0.115972","0|0:0.681:0.434878,0.449150,0.115972"
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
31138,22,51233182,22:51233182:G:A,G,A,.,.,IMPUTED,GT:DS:GP,"0|0:0.779:0.364262,0.492383,0.143355",...,"0|1:0.901:0.291665,0.515831,0.192504","0|0:0.618:0.476901,0.428477,0.094622","0|0:0.652:0.454201,0.439491,0.106309","0|0:0.646:0.457918,0.438192,0.103890","0|0:0.680:0.433727,0.452473,0.113799","0|0:0.791:0.359252,0.490259,0.150489","0|0:0.624:0.473219,0.429708,0.097073","0|1:0.810:0.340886,0.508438,0.150676","0|0:0.742:0.394505,0.469266,0.136229","0|0:0.756:0.383289,0.477340,0.139371"
31139,22,51233312,22:51233312:A:G,A,G,.,.,IMPUTED,GT:DS:GP,"0|0:0.700:0.417376,0.465478,0.117147",...,"0|1:0.854:0.310230,0.525452,0.164318","0|0:0.591:0.495762,0.417147,0.087091","0|0:0.574:0.508013,0.409476,0.082512","0|0:0.571:0.510278,0.408119,0.081603","0|0:0.618:0.476164,0.429209,0.094627","0|0:0.629:0.469392,0.432556,0.098051","0|0:0.562:0.516767,0.404245,0.078989","0|0:0.660:0.446704,0.446192,0.107104","0|0:0.598:0.491303,0.419259,0.089438","0|0:0.712:0.411740,0.464736,0.123524"
31140,22,51233347,22:51233347:T:C,T,C,.,.,IMPUTED,GT:DS:GP,"0|0:0.751:0.380414,0.488556,0.131030",...,"0|1:0.868:0.302100,0.527790,0.170110","0|0:0.593:0.494266,0.418373,0.087360","0|0:0.579:0.504981,0.411285,0.083734","0|0:0.565:0.514622,0.405506,0.079872","0|0:0.598:0.491172,0.420089,0.088739","0|0:0.682:0.431482,0.454726,0.113792","0|0:0.564:0.515223,0.405224,0.079552","0|0:0.677:0.435891,0.451631,0.112478","0|0:0.623:0.474274,0.428810,0.096916","0|0:0.753:0.384101,0.479203,0.136695"
31141,22,51235979,22:51235979:G:A,G,A,.,.,IMPUTED,GT:DS:GP,"0|0:0.785:0.368889,0.477235,0.153876",...,"0|0:0.724:0.406733,0.462822,0.130444","0|0:0.651:0.454213,0.440361,0.105426","0|0:0.780:0.372146,0.475791,0.152063","0|0:0.797:0.361684,0.479565,0.158751","0|0:0.803:0.356894,0.483237,0.159869","0|0:0.674:0.439178,0.447564,0.113257","0|0:0.760:0.384311,0.471235,0.144454","0|0:0.696:0.423546,0.456953,0.119500","0|0:0.663:0.446514,0.443767,0.109720","0|0:0.719:0.408519,0.463703,0.127778"


In [13]:
# Write the VCF meta-information lines followed by the genotype table to
# output_path (gzip-compressed because the name ends in ".gz").
dl.write_ligated_results_to_vcf(genotypes, output_path)

In [None]:
# Final step after the results have been written: hand the Colab runtime back.
# NOTE(review): presumably this disconnects and frees the VM — confirm against
# the google.colab runtime documentation.
from google.colab import runtime
runtime.unassign()