In [1]:
# This notebook implements the last step in SDrecall to merge GATK variants with SDrecall variants.

# -- Required user-defined parameters -- # 
# 1. VCF generated with SDrecall
# 2. VCF generated with other algorithms (eg. GATK)
# 3. Sample ID (usually sample name) 
# 4. PED file of samples

# -- Main steps -- #
# A. Check if SDrecall VCF contains > 10 variants
# B. Make sure SDrecall VCF is in diploidy format
# C. Merge variants with priority
#  > Usage: merge_variants_with_priority.py -ov {SDrecall VCF} -pv {GATK VCF} -op {sample_ID.merged.vcf.gz}
# D. Output VCF should have > 10000 variants 
# E. Mark homoseq recall shortv & tabix indexing

# -- Ideas -- #
# 1. Create a class "Variant"
# 2. Version sort for chromosome order https://stackoverflow.com/questions/2574080/sorting-a-list-of-dot-separated-numbers-like-software-versions



In [2]:
import os
import pandas as pd
import logging
from src.utils import *
from pandarallel import pandarallel as pa
import time

import sys

ov_vcf = "/home/louisshe/shortVariantVCF/data/merge_vcfs/PID21-055.homo_region.filtered.vcf.gz"
pv_vcf = "/home/louisshe/shortVariantVCF/data/merge_vcfs/PID21-055.gatk.g.vcf.gz.chr13"

In [3]:
class Variant:
    """
    This class creates a Variant object for each record in a graph.
    
    This is designated for VCFs with one sample.
    
    """
    
    import pandas as pd
    from difflib import SequenceMatcher
    
    def __init__(self, row, idx):
        
        """
        Variant constructor based on a pandas dataframe
        
        Note:
        -------
        Rows should be taken from a dataframe sorted by chromosome (chrM, chr1, ..., chrX, chrY), or by the order in fai file
        
        Error:
        -------
        KeyError: No INSLEN in INS records
        ValueError: No sample calls
        
        """
        
        self.CHROM = row["#CHROM"]
        self.START = int(row["POS"])
        self.ID = row["ID"]
        self.FILTER = row["FILTER"]
        self.QUAL = row["QUAL"]
        self.REF = row["REF"]
        self.ALT = row["ALT"].split(",")[0] # Only consider the biallelic case
        self.idx = int(row["POS"])
        
        # Data from INFO field
        self._INFO = {field.split("=")[0]: field.split("=")[1:] for field in row["INFO"].split(";")}
        self.END = int(self._INFO["END"][0]) if "END" in self._INFO.keys() else row["POS"]
        self.VARTYPE = self._INFO["SVTYPE"][0] if "SVTYPE" in self._INFO.keys() else "short_variant"
        if self.VARTYPE == "INS":
            try:
                self.INSEND = self.START + int(self._INFO["INSLEN"][0])
            except KeyError:
                raise KeyError(f"No INSLEN field in INS variant.\n{row} ")
            self.INSSEQ = self._INFO["INSSEQ"] if "INSSEQ" in self._INFO.keys() else "N"
        
        # Data from sample field
        if row.shape[0] - 9 > 1:
            logging.warning(f"Multiple samples found. Only consider the first sample. ")
        elif row.shape[0] == 9:
            raise ValueError(f"No sample found in the given VCF for the variants.\n{self.__repr__()}")        
        
        _sample = [row["FORMAT"].split(":")] + [row[col].split(":") for col in range(9, row.shape[0], 1)]
        self.SAMPLE = {sample[0]: sample[1] for sample in zip(*_sample)}
        try:
            self.DP = float(self.SAMPLE["DP"])
        except (KeyError, TypeError): # Case: DP not available or DP = "."
            self.DP = "."
        
        try:
            self.GT = self.SAMPLE["GT"] if self.DP != "." and self.DP != 0.0 else "."
        except:
            raise KeyError(f"GT not found in the variant:\n{self.__repr__()}")
        
        self.MISSING_GT = (self.GT == "./.") or (self.GT == ".|.") or (self.GT == ".")
            
        if "AD" in self.SAMPLE.keys():
            self.AD = list(map(int, self.SAMPLE["AD"].split(",")))
            if self.AD[0] > 0 and self.AD == 0:
                self.GT = "0"
        
    def __repr__(self):
        """
        String representation of the variant.
        """
        return f"{self.CHROM}:{self.START}-{self.END} {self.REF}>{self.ALT}\nType: {self.VARTYPE}\nGenotype: {self.GT}\nSample: {self.SAMPLE}"
    
    def getSeries(self):

        keys = ["#CHROM", "POS", "ID", "REF", "ALT", "QUAL", "FILTER", "INFO", "FORMAT", row.index[9]]
        try:
            values = [self.CHROM, self.START, self.ID, self.REF, self.ALT, self.QUAL, self.FILTER, 
                 ";".join([f"{key}={value[0]}" for key, value in self._INFO.items()]),
                 ":".join(self.SAMPLE.keys()),
                 ":".join(self.SAMPLE.values())]
        except:
            values = [self.CHROM, self.START, self.ID, self.REF, self.ALT, self.QUAL, self.FILTER, 
                 ".",
                 ":".join(self.SAMPLE.keys()),
                 ":".join(self.SAMPLE.values())]

        return pd.Series(values, index=keys)
    
    def addFilter(self, *tag):
        
        tags = list(tag)
        if self.FILTER == "PASS":
            self.FILTER = ";".join(tags)
        else:
            of = self.FILTER.split(";") + tags
            self.FILTER = ";".join(of)     
    
    def isAdjacent(self, other):
        
        s_thresh = 1
        
        if self == other:
            return True
        else:
            l, r = sorted([self, other])
            return (r.START - l.END <= s_thresh)     
    
    def overlap_fraction(self, other):
        
        l, r = sorted([self, other])
        if r.START < l.END:
            return abs((r.START - l.END) / (r.END - l.START))
        else:
            return 0.0
    
    def __eq__(self, other):
        
        f_thresh = 0.95
        
        if (self.VARTYPE != other.VARTYPE) or (self.CHROM != other.CHROM):
            return False
        elif self.VARTYPE == other.VARTYPE == "short_variant":
            return (self.START == other.START) and (self.END == other.END) \
                    and (self.REF == other.REF) and (self.ALT == other.ALT)
        elif self.VARTYPE == other.VARTYPE == "BND":
            return self.isAdjacent(other)
        else:
            if self.VARTYPE != "INS":
                return (self.overlap_fraction(other) >= f_thresh)
            elif not self.isAdjacent(other):
                return False
            else:
                ofrac = self.overlap_fraction(other)
                if ofrac == 0.0:
                    return False
                elif self.INSSEQ != "N" and other.INSSEQ != "N":
                    seqsim = SequenceMatcher(None, self.INSSEQ, other.INSSEQ).ratio()
                    return (seqsim >= f_thresh)
                else:
                    return (ofrac >= f_thresh)

    def __ne__(self, other):
        return not (self == other)
    
    def __lt__(self, other):
        return self.idx < other.idx
    
    def __le__(self, other):
        return self.idx <= other.idx
    
    def __gt__(self, other):
        return self.idx > other.idx
    
    def __ge__(self, other):
        return self.idx >= other.idx

def bsearch(var, low, high, x) -> tuple:
    """
    Binary search function for finding SDrecall variants.
    
    Return:
    -------
    Found: Variant object, index of the object
    Not found: None
    
    """
    if high >= low:
        mid = (high + low) // 2
        if var[mid] == x:
            return var[mid], mid
        elif var[mid] > x:
            return bsearch(var, low, mid - 1, x)
        else:
            return bsearch(var, mid + 1, high, x)
    else:
        return None, None

def pick_rep_rec(row, ov_vars, pv_tag, ov_tag) -> pd.Series:
    """
    This function processes variants which are found in both VCFs.
    """  
    pv_var = Variant(row, row.name)
    ov_var, idx = bsearch(ov_vars, 0, len(ov_vars) - 1, pv_var)
    if ov_var:
        dup_ov_vars.append(idx) # Keep indices of found items
        if pv_var.GT == ov_var.GT:
            pv_var.addFilter(ov_tag, pv_tag)
            return pv_var.getSeries()
        elif "0" in pv_var.GT or "." in pv_var.GT:
            ov_var.addFilter(ov_tag, pv_tag)
            return ov_var.getSeries()
        elif "0" in ov_var.GT or "." in ov_var.GT:
            pv_var.addFilter(ov_tag, pv_tag)
            return pv_var.getSeries()
        else:
            pv_var.addFilter(ov_tag, pv_tag)
            return pv_var.getSeries()
    else: # Not found in SDrecall VCF
        if "0" in pv_var.GT or "." in pv_var.GT:
            return None
        else:
            pv_var.addFilter(pv_tag)
            return pv_var.getSeries()
        
def concat_headers(header_1, header_2) -> list[str]:
    """
    This function takes two VCF headers and returns a merged header. 
    """
    # Get original order
    fields = [ field.split("=")[0].strip("#") for field in header_1 ] + [ field.split("=")[0].strip("#") for field in header_2 ]
    _unique_fields = list(dict.fromkeys(fields))
    order = dict(zip(_unique_fields, range(len(_unique_fields))))
    
    # Merge headers  
    merged_header = sorted(list(set(header_1 + header_2)), key=lambda x: order[x.split("=")[0].strip("#")])
    
    return merged_header
    

In [6]:
def merge(pv_vcf, ov_vcf, pv_tag, ov_tag):
    """
    This function merges prioritized VCF with original VCF on a chromosome-wise basis.
    """
    ov_header, ov_subjects, ov_df = loadVCF(ov_vcf)
    pv_header, pv_subjects = loadVCF(pv_vcf, omit_record=True)

    # Get chromosome number
    chrom = pv_vcf.split(".")[-1]
    logging.info(f"Processing chromosome {chrom} ... ")
    
    # Prepare ov variants
    ov_vars = []
    dup_ov_vars = []
    for idx, row in ov_df[ov_df["#CHROM"] == chrom].iterrows():
        ov_vars.append(Variant(row, idx))
    ov_vars = sorted(ov_vars)

    header = ov_df.columns
    processed = 0
    from_pv = pd.DataFrame()
    with pd.read_csv(pv_vcf, sep="\t", na_filter=False, engine="c", 
                     comment="#", header=None, names=header, compression="gzip", 
                     skiprows=32959110, chunksize=500000) as reader:
        for pv_chunk in reader:
            processed += 500000
            from_pv_tmp = pv_chunk.parallel_apply(pick_rep_rec, axis=1, args=(ov_vars, pv_tag, ov_tag,)).dropna()
            if from_pv_tmp.shape[0] != 0:
                from_pv = pd.concat([from_pv, from_pv_tmp], axis=0)
            logging.info(f"Processed {processed} variants. ")

    from_ov = ov_df[~ov_df.index.isin(dup_ov_vars)]
    pv_ov = pd.concat([from_pv, from_ov], axis=0).reset_index(drop=True)
    
    return pv_ov

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=17), Label(value='0 / 17'))), HBox…



In [85]:
# Merge head
ov_vcf = "/home/louisshe/shortVariantVCF/data/merge_vcfs/PID21-055.homo_region.filtered.vcf.gz"
pvcf = "/home/louisshe/shortVariantVCF/data/merge_vcfs/PID21-055.gatk.g.vcf.gz"
workers = 8
pv_tag = "GATK"
ov_tag = "SDrecall"

def main(pvcf, ov_vcf, outpath, pv_tag, ov_tag, workers):
    """
    Main function for merging prioritized variants and original variants.
    
    Arguments:
    ----------
    pvcf: Prioritized VCF holding variants from all contigs
    ov_vcf: Original VCF (eg. VCF from SDrecall)
    ov_tag: tag used for variants called from ov_vcf
    pv_tag: tag used for variants called from pvcf
    outpath: absolute path of output VCF (gzipped)
    workers: number of cores to use
        
    """
    start = time.time()
    # Load prioritized VCF and original VCF (as file streams)
    os.chdir(os.path.dirname(pvcf))
    os.makedirs("tmp", exist_ok=True)

    ov_header, ov_subjects, ov_df = loadVCF(ov_vcf)
    pv_header, pv_subjects = loadVCF(pv_vcf, omit_record=True)
    all_chr = [ pcontig[13:].split(",")[0] for pcontig in pv_header if pcontig.startswith("##contig") ]

    for contig in all_chr:
        cmd = f"bcftools filter -r {contig} -Oz -o {os.path.join('tmp/', os.path.basename(pvcf) + '.') + str(contig)} {pvcf} "
        executeCmd(cmd)
    logging.info(f"****************** VCF splitting completed in {time.time() - start:.2f} seconds ******************")
        
    os.chdir("tmp/") # Descend into tmp/
    pv_vcfs = [pv_vcf for pv_vcf in os.listdir()]

    # Write new VCF header
    ov_filter_head = f"##FILTER=<ID={ov_tag},Description='variants called from {ov_tag}'>"
    pv_filter_head = f"##FILTER=<ID={pv_tag},Description='variants called from {pv_tag}'>"
    ov_header.append(ov_filter_head)
    ov_header.append(pv_filter_head)
    merged_header = concat_headers(pv_header, ov_header)
    with gzip.open(outpath) as f:
        f.write("\n".join(merged_header).encode())
        f.write("\n".encode())
        f.write("\t".join(ov_df.columns.tolist()).encode())
        f.write("\n".encode())

    # Process prioritized variants
    pa.initialize(progress_bar=False, verbose=0, nb_workers=workers)

    for pv_vcf in pv_vcfs:
        merged_df = merge(pv_vcf, ov_vcf, pv_tag, ov_tag)
        merged_df.to_csv(outpath, sep="\t", index=False, header=False, mode="a", compression="gzip")

    # Cleanup
    os.chdir(os.path.dirname(pvcf))
    os.remove("tmp/")

if __name__ == "__main__":
    
    # Argparse setup
    parser = argparse.ArgumentParser(description = "Merge prioritized VCF and original VCF.")
    parser._optionals.title = "Options"
    ROOT = os.path.dirname(__file__)
    parser.add_argument("--pvcf", type = str, help = "prioritized VCF (gz)", required = True)
    parser.add_argument("--ovcf", type = str, help = "original VCF (gz)", required = True)
    parser.add_argument("--pv_tag", type = str, help = "tag used for variants from prioritized VCF", required = True)
    parser.add_argument("--ov_tag", type = str, help = "tag used for variants from original VCF", required = True)
    parser.add_argument("--outpath", type = str, help = "absolute output path of merged VCF (gz)", required = True)
    parser.add_argument("--thread", type = int, help = "number of threads (default: 8)", default = 8)
    parser.add_argument("-v", "--verbose", type = str, default = "INFO", help = "verbosity level (default: INFO)")
    args = parser.parse_args()
    logging.basicConfig(format='[%(asctime)s] %(levelname)s: %(message)s', datefmt='%a %b-%m %I:%M:%S%P',
                        level = args.verbose.upper())
    logging.debug(f"Working in {ROOT}")  
    
    main(args.pvcf, args.ovcf, args.outpath, args.pv_tag, args.ov_tag, args.thread)
    

In [84]:
ov_header, ov_subjects, ov_df = loadVCF(ov_vcf)
pv_header, pv_subjects = loadVCF(pv_vcf, omit_record=True)

# Add filters for GATK and SDrecall (pv_tag and ov_tag)



ov_filter_head = f"##FILTER=<ID={ov_tag},Description='variants called from {ov_tag}'>"
pv_filter_head = f"##FILTER=<ID={pv_tag},Description='variants called from {pv_tag}'>"
ov_header.append(ov_filter_head)
ov_header.append(pv_filter_head)
merged_header = concat_headers(pv_header, ov_header)


['##fileformat=VCFv4.2',
 '##FILTER=<ID=PASS,Description="All filters passed">',
 '##FILTER=<ID=LIKELY_INTRINSIC,Description="The variant called is likely due to intrinsic difference between homologous sequences in the ref genome.">',
 '##FILTER=<ID=LowQual,Description="Low quality">',
 '##FILTER=<ID=UNLIKELY_INTRINSIC,Description="The variant called is unlikely due to intrinsic difference between homologous sequences in the ref genome. ALT/REF ratio is higher than theoretical value for wildtype.">',
 '##ALT=<ID=NON_REF,Description="Represents any possible alternative allele not already represented at this location by REF and ALT">',
 '##FORMAT=<ID=GQ,Number=1,Type=Integer,Description="Genotype Quality">',
 '##FORMAT=<ID=MIN_DP,Number=1,Type=Integer,Description="Minimum DP observed within the GVCF block">',
 '##FORMAT=<ID=PGT,Number=1,Type=String,Description="Physical phasing haplotype information, describing how the alternate alleles are phased in relation to one another; will always 

In [58]:
def printRow(row):
    
    print(row.index[9])
#     var = Variant(row, row.name)
#     print(":".join(var.SAMPLE.values()))
#     print(";".join([f"{key}={value[0]}" for key, value in var._INFO.items()]))
#     print(var._INFO.keys())
#     print(Variant(row, row.name).getSeries())
#     print(type(row))
    sys.exit()
ov_df.apply(printRow, axis=1)

PID21-055


SystemExit: 

In [22]:
chrorder = ["chrM", "chr1", "chr2", "chr3", "chr4", "chr5", "chr6", "chr7", "chr8", "chr9", "chr10", 
                         "chr11", "chr12", "chr13", "chr14", "chr15", "chr16", "chr17", "chr18", "chr19", "chr20", 
                         "chr21", "chr22", "chrX", "chrY", "chr1_gl000191_random", "chr1_gl000192_random", 
                         "chr4_ctg9_hap1", "chr4_gl000193_random", "chr4_gl000194_random", 
                         "chr6_apd_hap1", "chr6_cox_hap2", "chr6_dbb_hap3", "chr6_mann_hap4", "chr6_mcf_hap5", "chr6_qbl_hap6", "chr6_ssto_hap7", 
                         "chr7_gl000195_random", "chr8_gl000196_random", "chr8_gl000197_random", 
                         "chr9_gl000198_random", "chr9_gl000199_random", "chr9_gl000200_random", "chr9_gl000201_random", 
                         "chr11_gl000202_random", "chr17_ctg5_hap1", 
                         "chr17_gl000203_random", "chr17_gl000204_random", "chr17_gl000205_random", "chr17_gl000206_random", 
                         "chr18_gl000207_random", "chr19_gl000208_random", "chr19_gl000209_random", "chr21_gl000210_random", 
                         "chrUn_gl000211", "chrUn_gl000212", "chrUn_gl000213", "chrUn_gl000214", "chrUn_gl000215", "chrUn_gl000216", 
                         "chrUn_gl000217", "chrUn_gl000218", "chrUn_gl000219", "chrUn_gl000220", "chrUn_gl000221", "chrUn_gl000222", 
                         "chrUn_gl000223", "chrUn_gl000224", "chrUn_gl000225", "chrUn_gl000226", "chrUn_gl000227", "chrUn_gl000228", 
                         "chrUn_gl000229", "chrUn_gl000230", "chrUn_gl000231", "chrUn_gl000232", "chrUn_gl000233", "chrUn_gl000234", 
                         "chrUn_gl000235", "chrUn_gl000236", "chrUn_gl000237", "chrUn_gl000238", "chrUn_gl000239", "chrUn_gl000240", 
                         "chrUn_gl000241", "chrUn_gl000242", "chrUn_gl000243", "chrUn_gl000244", "chrUn_gl000245", "chrUn_gl000246", 
                         "chrUn_gl000247", "chrUn_gl000248", "chrUn_gl000249"]

In [53]:
class Chromosome:
    
    def __init__(self, name):
        self.name = name
        self.GRCh37 = (self.name.lower().startswith("chr"))
        self.length = len(self.name)
        self.contig = self.name[3:] if self.GRCh37 else self.name
        
    def __hash__(self):
        """
        Hash function for self.contig (eg. M, 1, 2, ..., Un_gl000248)
        """
        contig_map = {
                       "M": 48,  "1": 49,  "2": 50,  "3": 51,  "4": 52,  "5": 53,  "6": 54,  "7": 55,
                       "8": 56,  "9": 57, "10": 58, "11": 59, "12": 60, "13": 61, "14": 62, "15": 63,
                      "16": 64, "17": 65, "18": 66, "19": 67, "20": 68, "21": 69, "22": 70,  "X": 71,
                       "Y": 72, "Un": 73
                     }
        key = contig_map.get(self.contig.split("_")[0], 74)
        if self.contig.split("_")[1:]:
            return key * self.length * sum(ord(letter) for letter in "_".join(self.contig.split("_")[1:]))
        else:
            return key * self.length * 1
    
    def __repr__(self):
        return self.name
    
    def __lt__(self, other):
        return self.__hash__() < other.__hash__()
    
    def __gt__(self, other):
        return self.__hash__() > other.__hash__()
    
    def __le__(self, other):
        return self.__hash__() <= other.__hash__()
    
    def __ge__(self, other):
        return self.__hash__() >= other.__hash__()
    
    def __eq__(self, other):
        return self.__hash__() == other.__hash__()
    
test = Chromosome("chrM")
test2 = Chromosome("chr1")
test == test
print(test)

chrM


#### TODO:
- [ ] merge_vcf_heads
- [ ] pick_rep_rec
- [ ] main_merge

In [55]:
chrom_cls = [ Chromosome(chrom) for chrom in chrorder ]
sorted(chrom_cls)

[chrM,
 chr1,
 chr2,
 chr3,
 chr4,
 chr5,
 chr6,
 chr7,
 chr8,
 chr9,
 chrX,
 chrY,
 chr10,
 chr11,
 chr12,
 chr13,
 chr14,
 chr15,
 chr16,
 chr17,
 chr18,
 chr19,
 chr20,
 chr21,
 chr22,
 chrUn_gl000211,
 chrUn_gl000220,
 chrUn_gl000212,
 chrUn_gl000221,
 chrUn_gl000230,
 chrUn_gl000213,
 chrUn_gl000222,
 chrUn_gl000231,
 chrUn_gl000240,
 chrUn_gl000214,
 chrUn_gl000223,
 chrUn_gl000232,
 chrUn_gl000241,
 chrUn_gl000215,
 chrUn_gl000224,
 chrUn_gl000233,
 chrUn_gl000242,
 chrUn_gl000216,
 chrUn_gl000225,
 chrUn_gl000234,
 chrUn_gl000243,
 chrUn_gl000217,
 chrUn_gl000226,
 chrUn_gl000235,
 chrUn_gl000244,
 chrUn_gl000218,
 chrUn_gl000227,
 chrUn_gl000236,
 chrUn_gl000245,
 chrUn_gl000219,
 chrUn_gl000228,
 chrUn_gl000237,
 chrUn_gl000246,
 chrUn_gl000229,
 chrUn_gl000238,
 chrUn_gl000247,
 chrUn_gl000239,
 chrUn_gl000248,
 chrUn_gl000249,
 chr6_dbb_hap3,
 chr6_apd_hap1,
 chr6_mcf_hap5,
 chr6_qbl_hap6,
 chr6_cox_hap2,
 chr4_ctg9_hap1,
 chr6_mann_hap4,
 chr6_ssto_hap7,
 chr17_ctg5_hap1,
