In [None]:
import pandas as pd
import numpy as np
import random
from collections import defaultdict, Counter
import os
import glob
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import roc_auc_score, average_precision_score, precision_score
from sklearn.svm import SVC
import matplotlib.pyplot as plt
import seaborn as sns
import gseapy as gp
import obonet
import networkx as nx
import random
import math
import time
import pickle

# GO Task

In [None]:
# Load the direct and propagated GO-BP gene sets plus the AGR GO slim table.
go_direct = gp.read_gmt("/data/gmt/hsa_EXP_ALL_BP_direct.gmt")
go_prop = gp.read_gmt("/data/gmt/hsa_ALL_BP_propagated.gmt")

file_path = "/data/slim_sets/goslim_agr.tsv"

go_slim = pd.read_csv(file_path, sep='\t')
# Pull the "GO_<digits>" token out of the "?x" column and rewrite it as "GO:<digits>".
go_slim["GO_ID"] = (
    go_slim["?x"].astype(str).str.extract(r'(GO_\d+)', expand=False).str.replace('_', ':')
)

# Keep only slim terms that also have a propagated gene set.
slim_set = set(go_slim['GO_ID'].tolist()) & set(go_prop.keys())

go_direct_20 = [go for go, gene_list in go_direct.items() if len(gene_list) >= 20]
go_prop_20 = [go for go, gene_list in go_prop.items() if len(gene_list) >= 20]

# Sorted so that everything derived from this list (in particular the [:3] pick per
# slim below) is identical across kernel restarts; iterating a set of strings is not.
go_use = sorted(set(go_direct_20).intersection(go_prop_20))

graph = obonet.read_obo("/data/obo/go.obo")

# obonet graphs have edges pointing from child term to parent term, so
# nx.ancestors(graph, slim) yields the terms *below* the slim in the ontology.
slim_ancestors = {}

for slim in sorted(slim_set):
    if slim not in graph:
        # Slim ID is absent from the OBO graph: report it and skip.
        print(slim)
        continue
    slim_ancestors[slim] = nx.ancestors(graph, slim)

# For each candidate GO term, the slim terms it falls under (possibly none).
go_to_slim = {
    term: [slim for slim, subterms in slim_ancestors.items() if term in subterms]
    for term in go_use
}

filtered_go_to_slim = {term: matches for term, matches in go_to_slim.items() if matches}
filt_slim_to_go = defaultdict(list)

for term, matches in filtered_go_to_slim.items():
    for match in matches:
        filt_slim_to_go[match].append(term)

# Pick up to three not-yet-used GO terms per slim; a term is assigned to one slim only.
go_use = []
go_use_selected = set()  # set mirror of go_use for O(1) membership tests
filtered_filt_slim_to_go = {}

for match, terms in filt_slim_to_go.items():
    available_terms = [term for term in terms if term not in go_use_selected]

    selected_terms = available_terms[:3]
    go_use.extend(selected_terms)
    go_use_selected.update(selected_terms)

    filtered_filt_slim_to_go[match] = selected_terms

go_prop_use = {term: genes for term, genes in go_prop.items() if term in go_use_selected}
go_slim_use = {term: genes for term, genes in go_prop.items() if term in slim_set}

In [None]:
# Build, for every selected GO term, a labelled hold-out set and three CV folds.
# All set -> list conversions go through sorted() because iteration order of a set of
# strings changes between Python processes, which would defeat random.seed(42).
random.seed(42)


def split_into_folds(items, n_folds=3):
    """Cut `items` into `n_folds` contiguous slices whose lengths differ by at most one.

    The first ``len(items) % n_folds`` slices each take one extra element.
    Returns a list of `n_folds` slices (some may be empty).
    """
    fold_size, leftover = divmod(len(items), n_folds)
    folds = []
    start = 0
    for i in range(n_folds):
        end = start + fold_size + (1 if i < leftover else 0)
        folds.append(items[start:end])
        start = end
    return folds


def _labelled_frame(pos, neg):
    """Return a DataFrame with columns gene/result (1 for `pos`, 0 for `neg`).

    The columns are declared explicitly so they exist even when both inputs are empty.
    """
    records = [{"gene": g, "result": 1} for g in pos] + \
              [{"gene": g, "result": 0} for g in neg]
    return pd.DataFrame(records, columns=["gene", "result"])


holdout_dict = {}
cv_fold1_dict = {}
cv_fold2_dict = {}
cv_fold3_dict = {}

# Loop-invariant: every gene attached to any slim term, and the slim term IDs.
all_values = {value for values in go_slim_use.values() for value in values}
all_slim_terms = set(go_slim_use.keys())

for go_term in list(go_prop_use):
    pos_genes = set(go_prop_use[go_term])
    pos_count = len(pos_genes)

    associated_slims = set(go_to_slim[go_term])
    other_slims = all_slim_terms - associated_slims

    # Get non-negative set (from the propagated slim)
    non_neg_list = set()
    for slim in associated_slims:
        non_neg_list.update(go_slim_use[slim])

    # Negative genes (candidate pool)
    neg_genes = all_values - pos_genes - non_neg_list

    # Number of negatives we want to sample
    neg_needed = 10 * pos_count

    chosen_negatives = set()

    if len(other_slims) > 0:
        neg_per_slim = neg_needed // len(other_slims)

        # only pick negatives, evenly from other slim terms
        for slim_term in sorted(other_slims):
            candidate_genes = sorted(neg_genes.intersection(go_slim_use[slim_term]))
            alloc_count = min(neg_per_slim, len(candidate_genes))
            if alloc_count > 0:
                chosen_negatives.update(random.sample(candidate_genes, alloc_count))

        # edge case handling: top up from the whole pool when the per-slim quota fell short
        allocated_count = len(chosen_negatives)
        if allocated_count < neg_needed:
            remaining_candidates = sorted(neg_genes - chosen_negatives)
            remainder = min(neg_needed - allocated_count, len(remaining_candidates))
            if remainder > 0:
                chosen_negatives.update(random.sample(remaining_candidates, remainder))
    else:
        # random.sample() rejects sets on Python >= 3.11, so sample from a sorted list.
        neg_pool = sorted(neg_genes)
        chosen_negatives = set(random.sample(neg_pool, min(neg_needed, len(neg_pool))))

    pos_list = sorted(pos_genes)
    neg_list = sorted(chosen_negatives)
    random.shuffle(pos_list)
    random.shuffle(neg_list)

    # 20% of each class goes to the hold-out set (at least one gene when the class is non-empty).
    holdout_pos_count = max(1, int(0.2 * len(pos_list))) if len(pos_list) > 0 else 0
    holdout_neg_count = max(1, int(0.2 * len(neg_list))) if len(neg_list) > 0 else 0

    holdout_pos = pos_list[:holdout_pos_count]
    holdout_neg = neg_list[:holdout_neg_count]

    train_pos = pos_list[holdout_pos_count:]
    train_neg = neg_list[holdout_neg_count:]

    pos_folds = split_into_folds(train_pos, 3)
    neg_folds = split_into_folds(train_neg, 3)

    holdout_df = _labelled_frame(holdout_pos, holdout_neg)
    fold_dfs = [_labelled_frame(pos_folds[i], neg_folds[i]) for i in range(3)]

    holdout_dict[go_term] = holdout_df
    cv_fold1_dict[go_term] = fold_dfs[0]
    cv_fold2_dict[go_term] = fold_dfs[1]
    cv_fold3_dict[go_term] = fold_dfs[2]

# Print one line per GO term with the positive/negative counts of every split.
split_tables = (
    ("Holdout", holdout_dict),
    ("Fold1", cv_fold1_dict),
    ("Fold2", cv_fold2_dict),
    ("Fold3", cv_fold3_dict),
)

for go_term in holdout_dict:
    pieces = []
    for split_label, split_table in split_tables:
        outcome = split_table[go_term]['result']
        pieces.append(
            f"{split_label}(Pos={(outcome == 1).sum()}, Neg={(outcome == 0).sum()})"
        )
    print(f"{go_term}: " + " | ".join(pieces))


In [None]:
# Pickle each GO split dictionary into its own file (paths are relative to the working directory).
data_dicts = [cv_fold1_dict, cv_fold2_dict, cv_fold3_dict, holdout_dict]
file_names = [
    "go_cv_fold1_dict_all.pkl",
    "go_cv_fold2_dict_all.pkl",
    "go_cv_fold3_dict_all.pkl",
    "go_holdout_dict_all.pkl",
]

for out_path, split_table in zip(file_names, data_dicts):
    with open(out_path, 'wb') as handle:
        pickle.dump(split_table, handle)

# Last expression of the cell: shows the written file names.
file_names

# OMIM Task

In [None]:
# Load the direct and propagated OMIM/DOID gene sets plus the AGR DOID slim table.
doid_direct = gp.read_gmt("/data/gmt/omim.20231030.direct.gmt")
doid_prop = gp.read_gmt("/data/gmt/omim.20231030.prop.gmt")

file_path = "/data/slim_sets/doid_agr_slim.tsv"

doid_slim = pd.read_csv(file_path, sep='\t')
slim_set = set(doid_slim['doid'].tolist())
# NOTE(review): unlike the GO task, slim_set is not intersected with doid_prop keys,
# so a slim term may have no entry in doid_slim_use below — confirm this is intended.

doid_prop_20 = [doid for doid, gene_list in doid_prop.items() if len(gene_list) >= 20]
# Sorted so downstream structures are built in the same order on every run.
doid_use = sorted(set(doid_prop_20))

graph = obonet.read_obo("/data/obo/doid.obo")

# obonet graphs have edges pointing from child term to parent term, so
# nx.ancestors(graph, slim) yields the terms *below* the slim in the ontology.
slim_ancestors = {}

# key=str keeps the sort well-defined even if the column holds non-string values (e.g. NaN).
for slim in sorted(slim_set, key=str):
    if slim not in graph:
        # Slim ID is absent from the OBO graph: report it and skip.
        print(slim)
        continue
    slim_ancestors[slim] = nx.ancestors(graph, slim)

# For each candidate DOID term, the slim terms it falls under (possibly none).
doid_to_slim = {
    term: [slim for slim, subterms in slim_ancestors.items() if term in subterms]
    for term in doid_use
}

filtered_doid_to_slim = {term: matches for term, matches in doid_to_slim.items() if matches}
filt_slim_to_doid = defaultdict(list)

for term, matches in filtered_doid_to_slim.items():
    for match in matches:
        filt_slim_to_doid[match].append(term)

# Assign every matched term to the first slim that lists it (no per-slim cap here).
doid_use = []
doid_use_selected = set()  # set mirror of doid_use for O(1) membership tests
filtered_filt_slim_to_doid = {}

for match, terms in filt_slim_to_doid.items():
    selected_terms = [term for term in terms if term not in doid_use_selected]
    doid_use.extend(selected_terms)
    doid_use_selected.update(selected_terms)

    filtered_filt_slim_to_doid[match] = selected_terms

doid_prop_use = {term: genes for term, genes in doid_prop.items() if term in doid_use_selected}
doid_slim_use = {term: genes for term, genes in doid_prop.items() if term in slim_set}

In [None]:
# Build, for every selected DOID term, a labelled hold-out set and three CV folds.
# All set -> list conversions go through sorted() because iteration order of a set of
# strings changes between Python processes, which would defeat random.seed(42).
random.seed(42)


def split_into_folds(items, n_folds=3):
    """Cut `items` into `n_folds` contiguous slices whose lengths differ by at most one.

    The first ``len(items) % n_folds`` slices each take one extra element.
    Returns a list of `n_folds` slices (some may be empty).
    """
    fold_size, leftover = divmod(len(items), n_folds)
    folds = []
    start = 0
    for i in range(n_folds):
        end = start + fold_size + (1 if i < leftover else 0)
        folds.append(items[start:end])
        start = end
    return folds


def _labelled_frame(pos, neg):
    """Return a DataFrame with columns gene/result (1 for `pos`, 0 for `neg`).

    The columns are declared explicitly so they exist even when both inputs are empty.
    """
    records = [{"gene": g, "result": 1} for g in pos] + \
              [{"gene": g, "result": 0} for g in neg]
    return pd.DataFrame(records, columns=["gene", "result"])


holdout_dict = {}
cv_fold1_dict = {}
cv_fold2_dict = {}
cv_fold3_dict = {}

# Loop-invariant: every gene attached to any slim term, and the slim term IDs.
all_values = {value for values in doid_slim_use.values() for value in values}
all_slim_terms = set(doid_slim_use.keys())

for doid_term in list(doid_prop_use):
    pos_genes = set(doid_prop_use[doid_term])
    pos_count = len(pos_genes)

    associated_slims = set(doid_to_slim[doid_term])
    other_slims = all_slim_terms - associated_slims

    # Get non-negative set (from the propagated slim). A slim that is in the ontology
    # graph but has no gene set in doid_slim_use simply contributes no genes.
    non_neg_list = set()
    for slim in associated_slims:
        non_neg_list.update(doid_slim_use.get(slim, ()))

    # Negative genes (candidate pool)
    neg_genes = all_values - pos_genes - non_neg_list

    # Number of negatives we want to sample
    neg_needed = 10 * pos_count

    chosen_negatives = set()

    if len(other_slims) > 0:
        neg_per_slim = neg_needed // len(other_slims)

        # only pick negatives, evenly from other slim terms
        for slim_term in sorted(other_slims):
            candidate_genes = sorted(neg_genes.intersection(doid_slim_use[slim_term]))
            alloc_count = min(neg_per_slim, len(candidate_genes))
            if alloc_count > 0:
                chosen_negatives.update(random.sample(candidate_genes, alloc_count))

        # top up from the whole pool when the per-slim quota fell short
        allocated_count = len(chosen_negatives)
        if allocated_count < neg_needed:
            remaining_candidates = sorted(neg_genes - chosen_negatives)
            remainder = min(neg_needed - allocated_count, len(remaining_candidates))
            if remainder > 0:
                chosen_negatives.update(random.sample(remaining_candidates, remainder))
    else:
        # random.sample() rejects sets on Python >= 3.11, so sample from a sorted list.
        neg_pool = sorted(neg_genes)
        chosen_negatives = set(random.sample(neg_pool, min(neg_needed, len(neg_pool))))

    pos_list = sorted(pos_genes)
    neg_list = sorted(chosen_negatives)
    random.shuffle(pos_list)
    random.shuffle(neg_list)

    # 20% of each class goes to the hold-out set (at least one gene when the class is non-empty).
    holdout_pos_count = max(1, int(0.2 * len(pos_list))) if len(pos_list) > 0 else 0
    holdout_neg_count = max(1, int(0.2 * len(neg_list))) if len(neg_list) > 0 else 0

    holdout_pos = pos_list[:holdout_pos_count]
    holdout_neg = neg_list[:holdout_neg_count]

    train_pos = pos_list[holdout_pos_count:]
    train_neg = neg_list[holdout_neg_count:]

    pos_folds = split_into_folds(train_pos, 3)
    neg_folds = split_into_folds(train_neg, 3)

    holdout_df = _labelled_frame(holdout_pos, holdout_neg)
    fold_dfs = [_labelled_frame(pos_folds[i], neg_folds[i]) for i in range(3)]

    holdout_dict[doid_term] = holdout_df
    cv_fold1_dict[doid_term] = fold_dfs[0]
    cv_fold2_dict[doid_term] = fold_dfs[1]
    cv_fold3_dict[doid_term] = fold_dfs[2]

# Print one line per DOID term with the positive/negative counts of every split.
split_tables = (
    ("Holdout", holdout_dict),
    ("Fold1", cv_fold1_dict),
    ("Fold2", cv_fold2_dict),
    ("Fold3", cv_fold3_dict),
)

for doid_term in holdout_dict:
    pieces = []
    for split_label, split_table in split_tables:
        outcome = split_table[doid_term]['result']
        pieces.append(
            f"{split_label}(Pos={(outcome == 1).sum()}, Neg={(outcome == 0).sum()})"
        )
    print(f"{doid_term}: " + " | ".join(pieces))


In [None]:
# Pickle each OMIM split dictionary into its own file (paths are relative to the working directory).
data_dicts = [cv_fold1_dict, cv_fold2_dict, cv_fold3_dict, holdout_dict]
file_names = [
    "omim_cv_fold1_dict_all.pkl",
    "omim_cv_fold2_dict_all.pkl",
    "omim_cv_fold3_dict_all.pkl",
    "omim_holdout_dict_all.pkl",
]

for out_path, split_table in zip(file_names, data_dicts):
    with open(out_path, 'wb') as handle:
        pickle.dump(split_table, handle)

# Last expression of the cell: shows the written file names.
file_names

# Generalization Task

In [None]:
# Read the two GMT annotation sets used by the generalization task.
go_prop = gp.read_gmt("/data/gmt/GO_BP_EXP_propagated_10_since_240328.gmt")
# NOTE(review): this path is relative, while the same file name is read from
# "/data/gmt/..." in the GO task cell — confirm which location is intended.
go_older = gp.read_gmt("data/gmt/hsa_EXP_ALL_BP_direct.gmt")

# Reference gene list, one gene per line; kept as a set because only genes in this
# intersecting set are checked.
with open("/go_generalization_folds/intersect_ref_genelist.txt", 'r') as handle:
    reference_genes = {entry.strip() for entry in handle}

def filter_go_dict(go_dict, reference_genes):
    """Restrict every gene list in `go_dict` to the genes found in `reference_genes`.

    Parameters
    ----------
    go_dict : mapping of term -> iterable of genes.
    reference_genes : any iterable of genes (set, list, ...); converted to a set once.

    Returns
    -------
    dict mapping term -> sorted list of the term's genes that are in `reference_genes`.
    Terms left with no genes are dropped. `go_dict` itself is not modified.
    """
    if isinstance(reference_genes, (set, frozenset)):
        reference = reference_genes
    else:
        reference = set(reference_genes)

    filtered = {}
    for term, genes in go_dict.items():
        # Intersect once per term instead of once for the value and once for the filter.
        kept = reference.intersection(genes)
        if kept:
            filtered[term] = sorted(kept)
    return filtered
    
# Restrict both annotation sets to the reference genes (older set first, then propagated).
go_older, go_prop = (
    filter_go_dict(annotations, reference_genes)
    for annotations in (go_older, go_prop)
)

In [None]:
# Build per-term CV folds from the older annotations and two hold-out sets:
#   holdout_dict        - genes present in go_prop but not in go_older (positives)
#   holdout_older_dict  - a sample of the older positives
# Both hold-outs share the same sampled negatives.
# All set -> list conversions go through sorted() because iteration order of a set of
# strings changes between Python processes, which would defeat random.seed(42).
random.seed(42)


def split_into_folds(items, n_folds=3):
    """Cut `items` into `n_folds` contiguous slices whose lengths differ by at most one.

    The first ``len(items) % n_folds`` slices each take one extra element.
    Returns a list of `n_folds` slices (some may be empty).
    """
    size, leftover = divmod(len(items), n_folds)
    folds = []
    start = 0
    for i in range(n_folds):
        end = start + size + (1 if i < leftover else 0)
        folds.append(items[start:end])
        start = end
    return folds


def _labelled_frame(pos, neg):
    """Return a DataFrame with columns gene/result (1 for `pos`, 0 for `neg`).

    The columns are declared explicitly so that an empty fold still has a "result"
    column; without this, counting positives on an empty fold raises KeyError.
    """
    records = [{"gene": g, "result": 1} for g in pos] + \
              [{"gene": g, "result": 0} for g in neg]
    return pd.DataFrame(records, columns=["gene", "result"])


holdout_dict        = {}
holdout_older_dict  = {}
cv_fold1_dict       = {}
cv_fold2_dict       = {}
cv_fold3_dict       = {}

# Loop-invariant: every gene that appears in any older annotation.
all_values = {v for vals in go_older.values() for v in vals}

for go_term in list(go_prop):
    pos_genes = set(go_older.get(go_term, []))
    pos_count = len(pos_genes)
    if pos_count == 0:
        continue

    # Hold-out positives: genes in go_prop for this term that are not older positives.
    prop_genes = set(go_prop.get(go_term, []))
    hpos = prop_genes - pos_genes
    if len(hpos) <= 1:
        # Checked before any sampling so skipped terms do not consume random draws.
        continue

    # Negatives: genes from the older annotations that are in neither gene set of this term.
    neg_genes = all_values - pos_genes - prop_genes

    neg_needed = 10 * pos_count
    neg_pool = sorted(neg_genes)
    # random.sample returns the picks in random order, so train_neg needs no extra shuffle.
    train_neg = random.sample(neg_pool, min(neg_needed, len(neg_pool)))
    chosen_negatives = set(train_neg)

    # Shuffle the sorted positives so the contiguous folds below are random, not alphabetical.
    train_pos = sorted(pos_genes)
    random.shuffle(train_pos)

    # Older hold-out: as many older positives as there are new positives (capped),
    # with up to ten negatives per sampled positive.
    size_old = len(hpos)
    old_pos_list = random.sample(train_pos, min(size_old, len(train_pos)))
    old_neg_list = random.sample(train_neg, min(10 * len(old_pos_list), len(train_neg)))
    old_pos_sample = set(old_pos_list)
    old_neg_sample = set(old_neg_list)

    # Whatever was moved to the older hold-out is removed from the training pool.
    train_pos = [g for g in train_pos if g not in old_pos_sample]
    train_neg = [g for g in train_neg if g not in old_neg_sample]

    holdout_dict[go_term] = _labelled_frame(
        sorted(hpos), old_neg_list
    ).sample(frac=1, random_state=42).reset_index(drop=True)

    holdout_older_dict[go_term] = _labelled_frame(
        old_pos_list, old_neg_list
    ).sample(frac=1, random_state=42).reset_index(drop=True)

    pos_folds = split_into_folds(train_pos, 3)
    neg_folds = split_into_folds(train_neg, 3)

    cv_fold1_dict[go_term] = _labelled_frame(pos_folds[0], neg_folds[0])
    cv_fold2_dict[go_term] = _labelled_frame(pos_folds[1], neg_folds[1])
    cv_fold3_dict[go_term] = _labelled_frame(pos_folds[2], neg_folds[2])

# Keep only GO terms whose hold-out and all three CV folds each contain strictly more
# than MIN_POS positives.
MIN_POS = 5

count_sources = (holdout_dict, cv_fold1_dict, cv_fold2_dict, cv_fold3_dict)
valid_terms = []

for go_term in holdout_dict:
    pos_counts = [
        int((source[go_term]["result"] == 1).sum())
        for source in count_sources
    ]
    if min(pos_counts) > MIN_POS:
        valid_terms.append(go_term)

# Same order as the output file list: fold1, fold2, fold3, hold-out, older hold-out.
filtered_dicts = [
    {gt: table[gt] for gt in valid_terms}
    for table in (cv_fold1_dict, cv_fold2_dict, cv_fold3_dict, holdout_dict, holdout_older_dict)
]
(
    cv_fold1_dict_filtered,
    cv_fold2_dict_filtered,
    cv_fold3_dict_filtered,
    holdout_dict_filtered,
    holdout_older_dict_filtered,
) = filtered_dicts

print(f"Total GO terms before filtering         : {len(holdout_dict)}")
print(f"GO terms remaining after > {MIN_POS} filter: {len(valid_terms)}")


# Print one line per retained GO term; each entry pairs a label with the position of
# its table inside filtered_dicts.
summary_layout = (
    ("Holdout", 3),
    ("Fold1", 0),
    ("Fold2", 1),
    ("Fold3", 2),
    ("OlderHoldout", 4),
)

for go_term in filtered_dicts[3]:
    pieces = []
    for split_label, position in summary_layout:
        outcome = filtered_dicts[position][go_term]['result']
        pieces.append(
            f"{split_label}(Pos={(outcome == 1).sum()}, Neg={(outcome == 0).sum()})"
        )
    print(f"{go_term}: " + " | ".join(pieces))

In [None]:
# Output paths, listed in the same order as the tables in filtered_dicts.
file_names = [
    "go_generalization_folds/go_cv_fold1_dict_all.pkl",
    "go_generalization_folds/go_cv_fold2_dict_all.pkl",
    "go_generalization_folds/go_cv_fold3_dict_all.pkl",
    "go_generalization_folds/go_holdout_dict_all.pkl",
    "go_generalization_folds/go_holdout_older_dict_all.pkl",
]

for out_path, filtered_table in zip(file_names, filtered_dicts):
    # Pickle a shallow plain-dict copy of each filtered table.
    data_dicts = dict(filtered_table)
    with open(out_path, "wb") as handle:
        pickle.dump(data_dicts, handle)