In [None]:
# All code comes from: calculate_gene_similarity.py
import re
import time
import mygene
import pickle
import requests
import scanpy as sc
import multiprocessing as mp

from goatools.obo_parser import GODag
from tqdm import tqdm
from retrying import retry
from collections import defaultdict
from itertools import combinations
from typing import Set
import pandas as pd
import os

mg = mygene.MyGeneInfo()
global_global_ensembl_ids_mapping = {}
global_global_go_terms_mapping = {}
global_global_go_dag = None


def init_worker(ensembl_ids_mapping, go_terms_mapping, go_dag):

    global global_global_ensembl_ids_mapping
    global global_global_go_terms_mapping
    global global_global_go_dag
    global_global_ensembl_ids_mapping = ensembl_ids_mapping
    global_global_go_terms_mapping = go_terms_mapping
    global_global_go_dag = go_dag


def retry_if_ssl_or_connection_error_or_429(exception):
    if isinstance(exception, requests.exceptions.HTTPError) and exception.response.status_code == 429:
        return True
    return isinstance(exception, (requests.exceptions.SSLError, requests.exceptions.ConnectionError))


@retry(retry_on_exception=retry_if_ssl_or_connection_error_or_429, stop_max_attempt_number=5, wait_fixed=2000)
def fetch_go_terms(ensembl_id):
    try:
        result = mg.getgene(ensembl_id, fields='go')
        go_terms = []

        if 'go' in result:
            for category in ['BP', 'MF', 'CC']:
                terms = result['go'].get(category, [])
                if isinstance(terms, dict):
                    go_terms.append(terms['id'])
                elif isinstance(terms, list):
                    go_terms.extend(term['id'] for term in terms)

        time.sleep(10)
        return ensembl_id, go_terms
    except requests.exceptions.HTTPError as e:
        if e.response.status_code == 429:
            time.sleep(1)
            raise
        else:
            raise


def get_ensembl_ids(genes, species):
    ensembl_mapping = defaultdict(list)
    result = mg.querymany(genes, scopes='symbol',
                          fields='ensembl.gene', species=species)
    missing_gene = []
    for item in result:
        ensembl_data = item.get('ensembl', None)
        if ensembl_data:
            if isinstance(ensembl_data, list):
                ensembl_ids = [ens['gene']
                               for ens in ensembl_data if 'gene' in ens]
                ensembl_mapping[item['query']].extend(ensembl_ids)
            elif isinstance(ensembl_data, dict):
                gene_id = ensembl_data.get('gene')
                if gene_id:
                    ensembl_mapping[item['query']].append(gene_id)
                else:
                    ensembl_mapping[item['query']].append('Not Found')
                    missing_gene.append(item['query'])
        else:
            ensembl_mapping[item['query']].append('Not Found')
            missing_gene.append(item['query'])
    if missing_gene:
        print(f"Ensembl IDs not found for the following genes: {missing_gene}")

    return ensembl_mapping, missing_gene


def prefetch_go_terms(ensembl_ids, num_workers=8):
    go_terms_mapping = {}

    with mp.Pool(num_workers) as pool:
        for ensembl_id, go_terms in tqdm(pool.imap_unordered(fetch_go_terms, ensembl_ids), total=len(ensembl_ids), desc="Fetching GO terms"):
            go_terms_mapping[ensembl_id] = go_terms

    return go_terms_mapping


def calculate_gene_similarity(similarity_scores, go_terms1, go_terms2):
    n_i = len(go_terms1)
    n_j = len(go_terms2)

    if n_i == 0 or n_j == 0:
        return 0.0

    n_ij = sum(1 for score in similarity_scores.values() if score > 0)
    if n_ij == 0:
        return 0.0

    similarity_score = sum(similarity_scores.values()) / n_ij
    return similarity_score


def calculate_go_term_similarity(go_terms1, go_terms2, go_dag):
    similarities = defaultdict(float)
    for go1 in go_terms1:
        for go2 in go_terms2:
            if go1 in go_dag and go2 in go_dag:
                if go_dag[go1].namespace == go_dag[go2].namespace:
                    intersect = len(set(go_dag[go1].get_all_parents()) & set(
                        go_dag[go2].get_all_parents()))
                    length1 = go_dag[go1].depth
                    length2 = go_dag[go2].depth
                    if length1 > 0 and length2 > 0:
                        tmk = intersect / max(length1, length2)
                    else:
                        tmk = 0.0
                    similarities[(go1, go2)] = tmk
                else:
                    similarities[(go1, go2)] = 0.0

    return similarities


def compute_gene_pair_similarity(gene_pair):
    gene1, gene2 = gene_pair
    best_similarity = 0.0

    for ensembl1 in global_global_ensembl_ids_mapping.get(gene1, []):
        if ensembl1 == 'Not Found':
            continue
        go_ids1 = global_global_go_terms_mapping.get(ensembl1, [])
        for ensembl2 in global_global_ensembl_ids_mapping.get(gene2, []):
            if ensembl2 == 'Not Found':
                continue
            go_ids2 = global_global_go_terms_mapping.get(ensembl2, [])

            similarity_scores = calculate_go_term_similarity(
                go_ids1, go_ids2, global_global_go_dag)
            similarity = calculate_gene_similarity(
                similarity_scores, go_ids1, go_ids2)

            if similarity > best_similarity:
                best_similarity = similarity

    return (gene1, gene2, best_similarity)


def compute_all_gene_similarities_parallel(gene_pairs, ensembl_ids_mapping, go_terms_mapping, go_dag, output_dir, num_workers=4):
    print(f"Number of CPU cores currently in use: {num_workers}")
    with mp.Pool(num_workers, initializer=init_worker, initargs=(ensembl_ids_mapping, go_terms_mapping, go_dag)) as pool:
        results = list(
            tqdm(pool.imap(compute_gene_pair_similarity, gene_pairs),
                 total=len(gene_pairs), desc="Computing similarities")
        )
    gene_similarity_matrix = defaultdict(dict)
    for gene1, gene2, similarity in results:
        gene_similarity_matrix[gene1][gene2] = similarity
        gene_similarity_matrix[gene2][gene1] = similarity

    return gene_similarity_matrix


def create_dir(directory):
    if not os.path.exists(directory):
        os.makedirs(directory)

In [None]:
GODag_path = "./go-basic.obo"
go_dag = GODag(GODag_path)

genes = ['MUC19', 'ATOH8', 'TLR6', 'AGR3', 'GDF15', 'PGC', 'HBB', 'TGIF1',
         'LPAR1', 'FAM107A', 'PSMB8', 'TRPM5', 'CALML5', 'EMP1', 'CGA', 'CCL19',
         'HIST1H4B', 'SLURP1', 'TSPAN16', 'VGF', 'STOM', 'CPB1', 'XCL2', 'PGAM1',
         'CASP7', 'DSTN', 'AIF1L', 'SCN4A', 'HNF4G', 'MSMB', 'GDPD2', 'HOXB5', 'BCL2L14', 'ITGB6', 'OLFML3', 'FBN2', 'HIST1H1E', 'ANGPT1']

output_dir = "./output"
species = "human"
ensembl_ids_mapping, missing_gene = get_ensembl_ids(
    list(genes), species)

# Extract all valid Ensembl IDs
all_ensembl_ids = set(
    ens_id for gene, ens_ids in ensembl_ids_mapping.items()
    for ens_id in ens_ids if ens_id != 'Not Found'
)

# Prefetch all GO terms
go_terms_mapping = prefetch_go_terms(
    list(all_ensembl_ids), num_workers=80)

# Filter out genes without Ensembl IDs
valid_genes = [gene for gene in genes if ensembl_ids_mapping[gene]
               and 'Not Found' not in ensembl_ids_mapping[gene]]
invalid_genes = list(set(genes) - set(valid_genes))

# Generate gene pairs including only valid genes
gene_pairs = list(combinations(valid_genes, 2))
print(f"Generated {len(gene_pairs)} valid gene pairs in total.")

# Set number of cores for parallel computation
num_workers_compute = mp.cpu_count() - 2
print(
    f"Number of cores used for similarity computation: {num_workers_compute}")

# Compute gene similarities
gene_similarity_matrix = compute_all_gene_similarities_parallel(
    gene_pairs,
    ensembl_ids_mapping,
    go_terms_mapping,
    go_dag,
    output_dir,
    num_workers=num_workers_compute
)

# Fill similarity with 0 for genes without Ensembl IDs
for gene in invalid_genes:
    gene_similarity_matrix[gene] = {
        other_gene: 0.0 for other_gene in genes}
    for other_gene in genes:
        gene_similarity_matrix[other_gene][gene] = 0.0


# Convert results to DataFrame and save as CSV
gene_similarity_matrix_df = pd.DataFrame(gene_similarity_matrix).fillna(0)

./go-basic.obo: fmt(1.2) rel(2024-06-17) 45,494 Terms


2 input query terms found no hit:	['HIST1H4B', 'HIST1H1E']


Ensembl IDs not found for the following genes: ['HIST1H4B', 'HIST1H1E']


Fetching GO terms: 100%|██████████| 44/44 [00:12<00:00,  3.64it/s]


Generated 630 valid gene pairs in total.
Number of cores used for similarity computation: 78
Number of CPU cores currently in use: 78


Computing similarities: 100%|██████████| 630/630 [00:07<00:00, 78.79it/s]


In [5]:
gene_similarity_matrix_df

Unnamed: 0,MUC19,ATOH8,TLR6,AGR3,GDF15,PGC,HBB,TGIF1,LPAR1,FAM107A,...,MSMB,GDPD2,HOXB5,BCL2L14,ITGB6,OLFML3,FBN2,ANGPT1,HIST1H1E,HIST1H4B
ATOH8,0.566667,0.0,0.368534,0.511905,0.376101,0.357307,0.299133,0.381109,0.345658,0.349141,...,0.584058,0.322404,0.454908,0.449447,0.268019,0.502637,0.319801,0.369535,0.0,0.0
TLR6,0.453199,0.368534,0.0,0.563413,0.373368,0.360974,0.267057,0.300455,0.324986,0.317339,...,0.513061,0.297105,0.311649,0.405361,0.219671,0.424657,0.257746,0.336901,0.0,0.0
AGR3,0.4,0.511905,0.563413,0.0,0.622549,0.3375,0.520833,0.483862,0.526355,0.55463,...,0.733333,0.339583,0.440936,0.540741,0.466667,0.533333,0.459524,0.590056,0.0,0.0
GDF15,0.57619,0.376101,0.373368,0.622549,0.0,0.319313,0.309772,0.346627,0.406763,0.402997,...,0.652564,0.367844,0.320798,0.531882,0.254527,0.608373,0.342918,0.398187,0.0,0.0
PGC,0.766667,0.357307,0.360974,0.3375,0.319313,0.0,0.340891,0.275633,0.300993,0.368385,...,0.648148,0.479924,0.367043,0.531939,0.278894,0.629293,0.296658,0.297296,0.0,0.0
HBB,0.581667,0.299133,0.267057,0.520833,0.309772,0.340891,0.0,0.281849,0.304369,0.333568,...,0.566239,0.350926,0.325461,0.458723,0.307848,0.504651,0.297285,0.263718,0.0,0.0
TGIF1,0.488889,0.381109,0.300455,0.483862,0.346627,0.275633,0.281849,0.0,0.31651,0.30903,...,0.53179,0.284008,0.368734,0.374541,0.261219,0.468027,0.286856,0.315033,0.0,0.0
LPAR1,0.54652,0.345658,0.324986,0.526355,0.406763,0.300993,0.304369,0.31651,0.0,0.374206,...,0.569656,0.362883,0.311397,0.480198,0.266646,0.542969,0.315144,0.352291,0.0,0.0
FAM107A,0.55119,0.349141,0.317339,0.55463,0.402997,0.368385,0.333568,0.30903,0.374206,0.0,...,0.585606,0.431635,0.347194,0.505918,0.272867,0.534028,0.319289,0.342069,0.0,0.0
PSMB8,0.453704,0.297478,0.236421,0.415789,0.274343,0.516877,0.340081,0.27162,0.278289,0.308932,...,0.499479,0.38825,0.329003,0.416118,0.317571,0.466912,0.293541,0.23736,0.0,0.0


In [6]:
gene_pairs

[('MUC19', 'ATOH8'),
 ('MUC19', 'TLR6'),
 ('MUC19', 'AGR3'),
 ('MUC19', 'GDF15'),
 ('MUC19', 'PGC'),
 ('MUC19', 'HBB'),
 ('MUC19', 'TGIF1'),
 ('MUC19', 'LPAR1'),
 ('MUC19', 'FAM107A'),
 ('MUC19', 'PSMB8'),
 ('MUC19', 'TRPM5'),
 ('MUC19', 'CALML5'),
 ('MUC19', 'EMP1'),
 ('MUC19', 'CGA'),
 ('MUC19', 'CCL19'),
 ('MUC19', 'SLURP1'),
 ('MUC19', 'TSPAN16'),
 ('MUC19', 'VGF'),
 ('MUC19', 'STOM'),
 ('MUC19', 'CPB1'),
 ('MUC19', 'XCL2'),
 ('MUC19', 'PGAM1'),
 ('MUC19', 'CASP7'),
 ('MUC19', 'DSTN'),
 ('MUC19', 'AIF1L'),
 ('MUC19', 'SCN4A'),
 ('MUC19', 'HNF4G'),
 ('MUC19', 'MSMB'),
 ('MUC19', 'GDPD2'),
 ('MUC19', 'HOXB5'),
 ('MUC19', 'BCL2L14'),
 ('MUC19', 'ITGB6'),
 ('MUC19', 'OLFML3'),
 ('MUC19', 'FBN2'),
 ('MUC19', 'ANGPT1'),
 ('ATOH8', 'TLR6'),
 ('ATOH8', 'AGR3'),
 ('ATOH8', 'GDF15'),
 ('ATOH8', 'PGC'),
 ('ATOH8', 'HBB'),
 ('ATOH8', 'TGIF1'),
 ('ATOH8', 'LPAR1'),
 ('ATOH8', 'FAM107A'),
 ('ATOH8', 'PSMB8'),
 ('ATOH8', 'TRPM5'),
 ('ATOH8', 'CALML5'),
 ('ATOH8', 'EMP1'),
 ('ATOH8', 'CGA'),
 ('AT