In [None]:
import pandas as pd
from pathlib import Path
import random
import polars as pl

In [None]:
# Input locations. Every input file lives under DATA_DIR, so pointing the
# notebook at another data directory means editing this one constant.
# The values are Path objects, matching the `Path` annotations on the classes
# that receive them.
DATA_DIR = Path("/path/to")

# Tab-separated phenotype annotation table ("#"-prefixed lines are comments).
phenotype_annotation = DATA_DIR / "phenotype.hpoa"
# Pipe-separated disease table without a header row.
disease_pg = DATA_DIR / "disease.pg"
# ClinVar VCF that the pathogenic variants are selected from.
clinvar_vcf_file_path = DATA_DIR / "clinvar_20240127_hg19.vcf"
# gnomAD exome sites VCF that the neutral variants are selected from.
gnomad_vcf_file_path = DATA_DIR / "gnomad.exomes.r2.1.1.sites.trimmed.vcf"

In [None]:
class GetCompatibleOMIMDiseases:
    """Find the OMIM diseases that pass the filters of both input files.

    A disease is kept when

    * the phenotype annotation file holds at least 10 rows for it with
      ``aspect == "P"``, and
    * the disease file lists it with an inheritance code of ``R``, ``D``,
      ``XD``, ``XR`` or ``B``.
    """

    # Column names given to the header-less, pipe-separated disease file.
    DISEASE_PG_COLUMNS = [
        "database_id",
        "gene_mim_number",
        "disease_name",
        "entrez_id",
        "diagnosis_status",
        "inheritance",
    ]

    def __init__(self, phenotype_annotation: Path, disease_pg: Path):
        """Store the locations of the two input files (nothing is read yet).

        Args:
            phenotype_annotation: Path of the tab-separated phenotype
                annotation file.
            disease_pg: Path of the pipe-separated disease file.
        """
        self.phenotype_annotation = phenotype_annotation
        self.disease_pg = disease_pg

    def read_phenotype_annotation(self):
        """Load the phenotype annotation file into a pandas DataFrame.

        Lines starting with ``#`` are treated as comments and skipped.
        """
        return pd.read_csv(self.phenotype_annotation, sep="\t", comment="#", low_memory=False)

    @staticmethod
    def get_compatible_omim_ids_from_phenotype_annotation(phenotype_annotation_df, min_annotations: int = 10):
        """Return OMIM ids that have enough phenotype annotations.

        Args:
            phenotype_annotation_df: pandas DataFrame with at least the
                ``aspect`` and ``database_id`` columns.
            min_annotations: Minimum number of ``aspect == "P"`` rows a
                disease needs in order to be kept.

        Returns:
            List of ``database_id`` values, in order of first appearance.
        """
        phenotypes = phenotype_annotation_df[phenotype_annotation_df["aspect"] == "P"]
        # na=False: a missing database_id counts as "not OMIM" instead of
        # producing a NaN mask that cannot be used for boolean indexing.
        omim_rows = phenotypes[phenotypes["database_id"].str.contains("OMIM", na=False)]
        annotation_counts = omim_rows.groupby("database_id")["database_id"].transform("count")
        return list(omim_rows.loc[annotation_counts >= min_annotations, "database_id"].unique())

    def compatible_omim_ids_phenotype_annotation(self):
        """Read the phenotype annotation file and return its compatible OMIM ids."""
        return self.get_compatible_omim_ids_from_phenotype_annotation(self.read_phenotype_annotation())

    def _read_disease_pg_table(self):
        """Load the complete disease file into a polars DataFrame (no filtering)."""
        return pl.read_csv(
            self.disease_pg,
            separator="|",
            new_columns=list(self.DISEASE_PG_COLUMNS),
            has_header=False,
        )

    def read_disease_pg(self):
        """Load the disease file, keeping only rows whose id starts with ``OMIM``."""
        return self._read_disease_pg_table().filter(pl.col("database_id").str.starts_with("OMIM"))

    @staticmethod
    def get_compatible_omim_ids_from_disease_pg(disease_pg_df):
        """Return the unique ids whose inheritance code is R, D, XD, XR or B.

        Args:
            disease_pg_df: polars DataFrame with ``inheritance`` and
                ``database_id`` columns.
        """
        inheritance = pl.col("inheritance")
        filtered = disease_pg_df.filter(
            (inheritance == "R")
            | (inheritance == "D")
            | (inheritance == "XD")
            | (inheritance == "XR")
            | (inheritance == "B")
        )
        return list(filtered["database_id"].unique())

    def compatible_omim_ids_disease_pg(self):
        """Read the disease file and return its compatible OMIM ids."""
        return self.get_compatible_omim_ids_from_disease_pg(self.read_disease_pg())

    def get_compatible_omim_diseases(self):
        """Return the OMIM ids accepted by both files, sorted alphabetically.

        The result is sorted because neither the set intersection nor the
        polars ``unique()`` call guarantees an order, and callers that draw
        randomly from this list would otherwise not be repeatable.
        """
        hpoa_omim = self.compatible_omim_ids_phenotype_annotation()
        disease_pg_omim = self.compatible_omim_ids_disease_pg()
        return sorted(set(hpoa_omim).intersection(disease_pg_omim))

    def get_compatible_omim_disease_dict(self):
        """Map every compatible OMIM id to the ``entrez_id`` values of its rows.

        All disease-file rows of a compatible disease contribute, whatever
        their inheritance code, in file order.

        Returns:
            ``{database_id: [entrez_id, ...]}`` with keys sorted alphabetically.
        """
        wanted = set(self.get_compatible_omim_diseases())
        id_pairs = self._read_disease_pg_table().select(["database_id", "entrez_id"]).rows()

        # One pass over the table instead of one filter per disease.
        entrez_ids_by_disease = {}
        for database_id, entrez_id in id_pairs:
            if database_id in wanted:
                entrez_ids_by_disease.setdefault(database_id, []).append(entrez_id)
        return {disease: entrez_ids_by_disease[disease] for disease in sorted(entrez_ids_by_disease)}
    

In [None]:
# OMIM ids accepted by both the phenotype annotation file and the disease file.
compatible_omim = GetCompatibleOMIMDiseases(
    phenotype_annotation=phenotype_annotation,
    disease_pg=disease_pg,
).get_compatible_omim_diseases()

In [None]:
class VariantSelector:
    """Select autosomal variants from a ClinVar VCF file.

    Only records on chromosomes ``1``-``22`` are considered. VCF columns are
    addressed by position: CHROM (0), POS (1), REF (3), ALT (4), INFO (7).
    """

    # CLNSIG values that mark a record as pathogenic for this selection.
    PATHOGENIC_SIGNIFICANCES = frozenset({"Pathogenic", "Likely_pathogenic"})

    def __init__(self, clinvar_vcf: Path, compatible_omim_ids: list):
        """Store the inputs; the VCF is only opened by the ``select_*`` methods.

        Args:
            clinvar_vcf: Path of the uncompressed ClinVar VCF file.
            compatible_omim_ids: OMIM ids (strings such as ``"OMIM:123456"``)
                a pathogenic variant must be linked to in order to be kept.
        """
        self.clinvar_vcf = clinvar_vcf
        self.compatible_omim_ids = compatible_omim_ids
        self.chromosomes = [str(x) for x in range(1, 23)]

    @staticmethod
    def get_chrom_field(vcf_fields: list):
        """Return the CHROM column of a split VCF record."""
        return vcf_fields[0]

    @staticmethod
    def get_pos_field(vcf_fields: list):
        """Return the POS column of a split VCF record as an int."""
        return int(vcf_fields[1])

    @staticmethod
    def get_ref_field(vcf_fields: list):
        """Return the REF column of a split VCF record."""
        return vcf_fields[3]

    @staticmethod
    def get_alt_field(vcf_fields: list):
        """Return the ALT column of a split VCF record."""
        return vcf_fields[4]

    @staticmethod
    def get_info_field(vcf_fields: list):
        """Return the INFO column of a split VCF record."""
        return vcf_fields[7]

    @staticmethod
    def _parse_info(info: str) -> dict:
        """Split an INFO string into ``{key: value}``; flags without ``=`` map to ``""``.

        When a key occurs more than once, its first value wins.
        """
        parsed = {}
        for item in info.split(";"):
            key, _, value = item.partition("=")
            parsed.setdefault(key, value)
        return parsed

    @staticmethod
    def _first_omim_id(clndisdb: str):
        """Return the first ``OMIM:...`` cross-reference of a CLNDISDB value, else None.

        The value is split on both ``|`` and ``,`` so the returned id never
        carries a neighbouring cross-reference or the ``CLNDISDB=`` prefix.
        """
        for disease_xrefs in clndisdb.split("|"):
            for xref in disease_xrefs.split(","):
                if xref.startswith("OMIM:"):
                    return xref
        return None

    def _iter_data_fields(self):
        """Yield the tab-split columns of every non-header line of the VCF."""
        with open(self.clinvar_vcf, "r") as vcf_file:
            for line in vcf_file:
                if line.startswith("#"):
                    continue
                yield line.strip().split("\t")

    def select_all_variants(self):
        """Return CHROM/POS/REF/ALT of every autosomal record.

        Lines with fewer than five columns are skipped.

        Returns:
            pandas DataFrame with columns ``CHROM``, ``POS``, ``REF``, ``ALT``
            (the columns are present even when no record is selected).
        """
        autosomes = set(self.chromosomes)
        selected_variants = []
        for fields in self._iter_data_fields():
            if len(fields) < 5 or self.get_chrom_field(fields) not in autosomes:
                continue
            selected_variants.append(
                {
                    "CHROM": self.get_chrom_field(fields),
                    "POS": self.get_pos_field(fields),
                    "REF": self.get_ref_field(fields),
                    "ALT": self.get_alt_field(fields),
                }
            )
        return pd.DataFrame(selected_variants, columns=["CHROM", "POS", "REF", "ALT"])

    def select_all_pathogenic_variants(self):
        """Return autosomal pathogenic variants linked to a compatible OMIM disease.

        A record is kept when all of the following hold:

        * its ``CLNSIG`` value is exactly ``Pathogenic`` or
          ``Likely_pathogenic`` (wherever ``CLNSIG`` sits in the INFO column);
        * the first OMIM id in its ``CLNDISDB`` value is one of
          ``compatible_omim_ids``;
        * it has a ``GENEINFO`` entry, whose text before the first ``:`` is
          reported as the gene.

        Records that lack ``CLNDISDB`` or ``GENEINFO``, or that have fewer
        than eight columns, are skipped.

        Returns:
            pandas DataFrame with columns ``CHROM``, ``POS``, ``REF``, ``ALT``,
            ``OMIM_ID``, ``PATHOGENICTY`` and ``GENE``.
        """
        autosomes = set(self.chromosomes)
        compatible_ids = set(self.compatible_omim_ids)
        selected_variants = []
        for fields in self._iter_data_fields():
            if len(fields) < 8 or self.get_chrom_field(fields) not in autosomes:
                continue
            info = self._parse_info(self.get_info_field(fields))
            if info.get("CLNSIG") not in self.PATHOGENIC_SIGNIFICANCES:
                continue
            omim_id = self._first_omim_id(info.get("CLNDISDB", ""))
            if omim_id is None or omim_id not in compatible_ids:
                continue
            gene_info = info.get("GENEINFO")
            if gene_info is None:
                continue
            selected_variants.append(
                {
                    "CHROM": self.get_chrom_field(fields),
                    "POS": self.get_pos_field(fields),
                    "REF": self.get_ref_field(fields),
                    "ALT": self.get_alt_field(fields),
                    "OMIM_ID": omim_id,
                    # Spelling kept as-is: this is the column name downstream code sees.
                    "PATHOGENICTY": "Pathogenic",
                    "GENE": gene_info.split(":")[0],
                }
            )
        return pd.DataFrame(
            selected_variants,
            columns=["CHROM", "POS", "REF", "ALT", "OMIM_ID", "PATHOGENICTY", "GENE"],
        )
    

In [None]:
# Two scans of the ClinVar VCF: the pathogenic subset linked to a compatible
# OMIM disease, and every autosomal record.
variant_selector = VariantSelector(
    clinvar_vcf=clinvar_vcf_file_path,
    compatible_omim_ids=compatible_omim,
)
pathogenic = variant_selector.select_all_pathogenic_variants()
all_clinvar_variants = variant_selector.select_all_variants()

In [None]:
class GnomADVariantSelector:
    """Select rare, PASS-filtered autosomal variants from a gnomAD VCF file.

    VCF columns are addressed by position: CHROM (0), POS (1), REF (3),
    ALT (4), FILTER (6), INFO (7).
    """

    def __init__(self, gnomad_vcf: Path, compatible_omim: list):
        """Store the inputs; the VCF is only opened by ``select_all_variants``.

        Args:
            gnomad_vcf: Path of the uncompressed gnomAD VCF file.
            compatible_omim: OMIM ids that are randomly assigned to the
                selected variants by ``filter_gnomad_variants``.
        """
        self.gnomad_vcf = gnomad_vcf
        self.compatible_omim = compatible_omim
        self.chromosomes = [str(x) for x in range(1, 23)]

    @staticmethod
    def get_chrom_field(vcf_fields: list):
        """Return the CHROM column of a split VCF record."""
        return vcf_fields[0]

    @staticmethod
    def get_pos_field(vcf_fields: list):
        """Return the POS column of a split VCF record as an int."""
        return int(vcf_fields[1])

    @staticmethod
    def get_ref_field(vcf_fields: list):
        """Return the REF column of a split VCF record."""
        return vcf_fields[3]

    @staticmethod
    def get_alt_field(vcf_fields: list):
        """Return the ALT column of a split VCF record."""
        return vcf_fields[4]

    @staticmethod
    def get_info_field(vcf_fields: list):
        """Return the INFO column of a split VCF record."""
        return vcf_fields[7]

    @staticmethod
    def get_filter_field(vcf_fields: list):
        """Return the FILTER column of a split VCF record."""
        return vcf_fields[6]

    @staticmethod
    def _parse_info(info: str) -> dict:
        """Split an INFO string into ``{key: value}``; flags without ``=`` map to ``""``.

        When a key occurs more than once, its first value wins.
        """
        parsed = {}
        for item in info.split(";"):
            key, _, value = item.partition("=")
            parsed.setdefault(key, value)
        return parsed

    @staticmethod
    def _parse_allele_frequency(raw_value):
        """Convert an ``AF`` value to float; return None when absent or not a number."""
        if raw_value is None:
            return None
        try:
            return float(raw_value)
        except ValueError:
            return None

    def select_all_variants(self, max_allele_frequency: float = 0.02):
        """Return rare autosomal variants whose FILTER column is ``PASS``.

        A record is kept when its ``AF`` INFO value lies strictly between 0
        and ``max_allele_frequency`` and its ``vep`` INFO value has at least
        four ``|``-separated parts; part 3 is reported as the gene and part 1
        as the consequence. Records with fewer than eight columns, without a
        numeric ``AF`` or without a usable ``vep`` value are skipped.

        Args:
            max_allele_frequency: Exclusive upper bound on ``AF``.

        Returns:
            pandas DataFrame with columns ``CHROM``, ``POS``, ``REF``, ``ALT``,
            ``GENE`` and ``CONSEQUENCE`` (present even when nothing is selected).
        """
        autosomes = set(self.chromosomes)
        selected_variants = []
        with open(self.gnomad_vcf, "r") as vcf_file:
            for line in vcf_file:
                if line.startswith("#"):
                    continue
                fields = line.strip().split("\t")
                if len(fields) < 8:
                    continue
                if self.get_chrom_field(fields) not in autosomes:
                    continue
                if self.get_filter_field(fields) != "PASS":
                    continue
                info = self._parse_info(self.get_info_field(fields))
                # Exact key lookup: keys that merely end in "AF" do not count.
                allele_frequency = self._parse_allele_frequency(info.get("AF"))
                if allele_frequency is None or not 0 < allele_frequency < max_allele_frequency:
                    continue
                # The vep value is split once, and only for records that passed the frequency filter.
                vep_parts = info.get("vep", "").split("|")
                if len(vep_parts) < 4:
                    continue
                selected_variants.append(
                    {
                        "CHROM": self.get_chrom_field(fields),
                        "POS": self.get_pos_field(fields),
                        "REF": self.get_ref_field(fields),
                        "ALT": self.get_alt_field(fields),
                        "GENE": vep_parts[3],
                        "CONSEQUENCE": vep_parts[1],
                    }
                )
        return pd.DataFrame(
            selected_variants,
            columns=["CHROM", "POS", "REF", "ALT", "GENE", "CONSEQUENCE"],
        )

    def filter_gnomad_variants(self, all_clinvar_variants: pd.DataFrame, seed=None):
        """Return the selected gnomAD variants that do not occur in ClinVar.

        Variants are compared on ``CHROM``, ``POS``, ``REF`` and ``ALT``.
        Each remaining variant gets an ``OMIM_ID`` drawn at random (with
        replacement) from ``compatible_omim``. The caller's
        ``all_clinvar_variants`` frame is not modified.

        Args:
            all_clinvar_variants: DataFrame with ``CHROM``, ``POS``, ``REF``
                and ``ALT`` columns.
            seed: Optional seed for the OMIM draw. With ``None`` the
                module-level ``random`` state is used, so a prior
                ``random.seed(...)`` call still controls the result.

        Returns:
            pandas DataFrame of gnomAD variants with an added ``OMIM_ID`` column.
        """
        key_columns = ["CHROM", "POS", "REF", "ALT"]
        gnomad_tagged = self.select_all_variants().assign(DB="gnomAD")
        # assign() returns a new frame, leaving the caller's DataFrame untouched.
        clinvar_tagged = all_clinvar_variants.assign(DB="clinvar")

        # ClinVar rows come first, so drop_duplicates keeps them and discards
        # the gnomAD copies of the same variant.
        merged_variants = pd.concat([clinvar_tagged, gnomad_tagged])
        deduplicated = merged_variants.drop_duplicates(subset=key_columns)
        gnomad_only = deduplicated[deduplicated["DB"] == "gnomAD"]

        rng = random if seed is None else random.Random(seed)
        diseases = rng.choices(self.compatible_omim, k=len(gnomad_only))
        return gnomad_only.assign(OMIM_ID=diseases).drop(columns=["DB"])
        

In [None]:
# filter_gnomad_variants assigns OMIM ids with random.choices, so the
# module-level RNG is seeded here to make the assignment repeatable. The ids
# are sorted because the draw depends on their order in the list.
random.seed(42)
gnomad_variants = GnomADVariantSelector(
    gnomad_vcf_file_path,
    sorted(compatible_omim),
).filter_gnomad_variants(
    # A copy is passed so that all_clinvar_variants keeps its original columns.
    all_clinvar_variants.copy()
)

In [None]:
class VariantBalancer:
    """Draw equally many pathogenic and gnomAD variants for every gene.

    For each gene present in the pathogenic variants, the same number of
    rows is sampled from both inputs; that number is the size of the smaller
    of the two per-gene subsets.
    """

    def __init__(self, pathogenic_variants, gnomad_variants):
        """Store both inputs and start with empty selections.

        Args:
            pathogenic_variants: DataFrame with a ``GENE`` column.
            gnomad_variants: DataFrame with a ``GENE`` column.
        """
        self.pathogenic_variants = pathogenic_variants
        self.gnomad_variants = gnomad_variants
        self.selected_pathogenic_variants = pd.DataFrame(columns=pathogenic_variants.columns)
        self.selected_benign_variants = pd.DataFrame(columns=self._benign_columns())

    def _benign_columns(self):
        """Columns of the benign selection: the pathogenic columns, then gnomAD-only ones."""
        pathogenic_columns = list(self.pathogenic_variants.columns)
        gnomad_only = [column for column in self.gnomad_variants.columns if column not in pathogenic_columns]
        return pathogenic_columns + gnomad_only

    @staticmethod
    def _combine(selections, columns):
        """Concatenate per-gene samples into one frame laid out with ``columns``."""
        if not selections:
            # pd.concat raises on an empty list; return an empty frame instead.
            return pd.DataFrame(columns=columns)
        return pd.concat(selections).reindex(columns=columns)

    def get_unique_genes(self):
        """Return the distinct ``GENE`` values of the pathogenic variants."""
        return self.pathogenic_variants['GENE'].unique()

    def balance_variants_per_gene(self):
        """Rebuild both selections with matching per-gene counts.

        Sampling uses ``random_state=42``. The selections are recomputed from
        scratch on every call, so calling this repeatedly does not accumulate
        rows. Columns a selection lacks (for example the pathogenic-only
        columns in the benign selection) are filled with missing values.
        """
        pathogenic_samples = []
        benign_samples = []
        for gene in self.get_unique_genes():
            pathogenic_subset = self.pathogenic_variants[self.pathogenic_variants["GENE"] == gene]
            gnomad_subset = self.gnomad_variants[self.gnomad_variants["GENE"] == gene]
            sample_size = min(len(pathogenic_subset), len(gnomad_subset))
            if sample_size == 0:
                continue
            pathogenic_samples.append(pathogenic_subset.sample(n=sample_size, random_state=42))
            benign_samples.append(gnomad_subset.sample(n=sample_size, random_state=42))

        # One concat per selection: DataFrame.append no longer exists in
        # pandas 2.x, and growing a frame inside the loop copies it each time.
        self.selected_pathogenic_variants = self._combine(
            pathogenic_samples, list(self.pathogenic_variants.columns)
        )
        self.selected_benign_variants = self._combine(benign_samples, self._benign_columns())

    def write_balanced_variants(
        self,
        pathogenic_path="pathogenic_training_variants.tsv",
        benign_path="neutral_training_variants.tsv",
    ):
        """Balance the variants and write both selections as tab-separated files.

        Args:
            pathogenic_path: Destination of the pathogenic selection.
            benign_path: Destination of the gnomAD (neutral) selection.
        """
        self.balance_variants_per_gene()
        self.selected_pathogenic_variants.to_csv(pathogenic_path, sep="\t", index=False)
        self.selected_benign_variants.to_csv(benign_path, sep="\t", index=False)


In [None]:
# Balance the two variant sets per gene and write both TSV files.
VariantBalancer(
    pathogenic_variants=pathogenic,
    gnomad_variants=gnomad_variants,
).write_balanced_variants()