In [None]:
# Load GO Biological Process gene sets: directly annotated and propagated versions.
go_direct = gp.read_gmt("/PATH/2024-09-08/hsa_ALL_BP_direct.gmt")
go_prop = gp.read_gmt("/PATH/hsa_ALL_BP_propagated.gmt")

# GO slim table: column "?x" holds an identifier containing "GO_<digits>",
# normalised here to the "GO:<digits>" form used as keys in the GMT files.
file_path = "/PATH/goslim_agr.tsv"

go_slim = pd.read_csv(file_path, sep='\t')
go_slim["GO_ID"] = (
    go_slim["?x"].astype(str)
    .str.extract(r'(GO_\d+)', expand=False)
    .str.replace('_', ':', regex=False)
)

# Slim terms that also have a propagated gene set (rows without a match are dropped).
slim_set = set(go_slim['GO_ID'].dropna()) & set(go_prop.keys())

# Minimum number of annotated genes for a term to be modelled.
MIN_GENES = 20

# NOTE(review): despite its name, go_direct_10 uses the same threshold (20) as
# go_prop_20 -- confirm whether 10 was intended. Name kept for other cells.
go_direct_10 = [go for go, gene_list in go_direct.items() if len(gene_list) >= MIN_GENES]
go_prop_20 = [go for go, gene_list in go_prop.items() if len(gene_list) >= MIN_GENES]

# Sorted (not list(set(...))): the iteration order of a set of strings changes
# with PYTHONHASHSEED, and later cells pick the "first N" terms per slim, so an
# unordered list here makes the selected terms differ from run to run.
go_use = sorted(set(go_direct_10).intersection(go_prop_20))

# Load the GO ontology as a networkx graph.
graph = obonet.read_obo("/PATH/go.obo")

# obonet directs edges child -> parent, so networkx "ancestors" of a slim node
# are the GO terms that sit *below* it in the ontology (its GO descendants).
# The dict name is kept for compatibility with later cells.
slim_ancestors = {}

# Sorted so that each term's list of matching slims (built below) has a
# reproducible order; iterating the set directly depends on PYTHONHASHSEED.
for slim in sorted(slim_set):
    if slim not in graph:
        # Slim term missing from this go.obo release (presumably obsolete/renamed).
        print(slim)
        continue
    slim_ancestors[slim] = nx.ancestors(graph, slim)

# For every candidate term, the slim terms whose descendant set contains it.
go_to_slim = {
    term: [slim for slim, members in slim_ancestors.items() if term in members]
    for term in go_use
}

# Drop terms that fall under no slim, then invert the mapping to slim -> terms.
filtered_go_to_slim = {term: matches for term, matches in go_to_slim.items() if matches}
filt_slim_to_go = defaultdict(list)

for term, matches in filtered_go_to_slim.items():
    for match in matches:
        filt_slim_to_go[match].append(term)

# Maximum number of terms selected under each slim category.
MAX_TERMS_PER_SLIM = 3

# Greedy selection: for each slim take up to MAX_TERMS_PER_SLIM terms that no
# earlier slim has already claimed, so every term is used at most once.
# Slims and their candidates are visited in sorted order so the result does not
# depend on upstream set/dict ordering (which varies with PYTHONHASHSEED).
go_use = []
used_terms = set()  # O(1) membership checks instead of scanning the go_use list
filtered_filt_slim_to_go = {}

for match in sorted(filt_slim_to_go):
    available_terms = [term for term in sorted(filt_slim_to_go[match]) if term not in used_terms]

    selected_terms = available_terms[:MAX_TERMS_PER_SLIM]
    go_use.extend(selected_terms)
    used_terms.update(selected_terms)

    filtered_filt_slim_to_go[match] = selected_terms

# Propagated gene sets for the selected terms and for the slim terms themselves.
go_prop_use = {term: genes for term, genes in go_prop.items() if term in used_terms}
go_slim_use = {term: genes for term, genes in go_prop.items() if term in slim_set}

In [None]:
# All sampling below goes through the global `random` module, seeded here.
random.seed(42)

NEG_PER_POS = 10     # negatives sampled per positive gene
HOLDOUT_FRAC = 0.2   # fraction of each class held out
N_FOLDS = 3          # must stay 3: one output dict per CV fold below


def split_into_folds(items, n_folds=3):
    """Split ``items`` into ``n_folds`` contiguous chunks whose sizes differ by at most one.

    The first ``len(items) % n_folds`` chunks each receive one extra item.
    """
    fold_size, n_extra = divmod(len(items), n_folds)
    folds = []
    start = 0
    for i in range(n_folds):
        end = start + fold_size + (1 if i < n_extra else 0)
        folds.append(items[start:end])
        start = end
    return folds


def sample_negatives(neg_genes, other_slims, slim_genes, neg_needed):
    """Sample up to ``neg_needed`` negative genes, spread across ``other_slims``.

    Each slim in ``other_slims`` contributes up to ``neg_needed // len(other_slims)``
    genes drawn from its overlap with ``neg_genes`` (``slim_genes[slim]`` gives a
    slim's genes). Any shortfall is topped up from the rest of ``neg_genes``. With no
    other slims, sample directly from ``neg_genes``.

    Every population is sorted before sampling: ``random.sample`` requires a
    sequence (passing a set raises TypeError on Python >= 3.11), and the iteration
    order of a set is not reproducible across runs, which would defeat the seed.
    """
    if not other_slims:
        pool = sorted(neg_genes)
        return set(random.sample(pool, min(neg_needed, len(pool))))

    chosen = set()
    per_slim = neg_needed // len(other_slims)
    for slim_term in sorted(other_slims):
        candidates = sorted(neg_genes.intersection(slim_genes[slim_term]))
        k = min(per_slim, len(candidates))
        if k > 0:
            chosen.update(random.sample(candidates, k))

    shortfall = neg_needed - len(chosen)
    if shortfall > 0:
        remaining = sorted(neg_genes - chosen)
        k = min(shortfall, len(remaining))
        if k > 0:
            chosen.update(random.sample(remaining, k))
    return chosen


def labelled_frame(pos, neg):
    """Return a gene/result DataFrame: ``pos`` genes labelled 1 followed by ``neg`` genes labelled 0.

    Columns are fixed so an empty split still has a 'result' column.
    """
    records = [{"gene": g, "result": 1} for g in pos] + [{"gene": g, "result": 0} for g in neg]
    return pd.DataFrame(records, columns=["gene", "result"])


# Loop invariants: every gene annotated to any slim term, and all slim term IDs.
all_slim_genes = {gene for genes in go_slim_use.values() for gene in genes}
all_slim_terms = set(go_slim_use.keys())

holdout_dict = {}
cv_fold1_dict = {}
cv_fold2_dict = {}
cv_fold3_dict = {}

for go_term in list(go_prop_use):
    pos_genes = set(go_prop_use[go_term])
    associated_slims = set(go_to_slim[go_term])

    # Genes under any slim this term belongs to are excluded from the negatives.
    non_neg_genes = {gene for slim in associated_slims for gene in go_slim_use[slim]}
    neg_genes = all_slim_genes - pos_genes - non_neg_genes

    other_slims = all_slim_terms - associated_slims
    chosen_negatives = sample_negatives(
        neg_genes, other_slims, go_slim_use, NEG_PER_POS * len(pos_genes)
    )

    # Sort before shuffling so the seeded shuffle is reproducible across runs.
    pos_list = sorted(pos_genes)
    neg_list = sorted(chosen_negatives)
    random.shuffle(pos_list)
    random.shuffle(neg_list)

    # Hold out HOLDOUT_FRAC of each class (at least one gene if the class is non-empty).
    holdout_pos_count = max(1, int(HOLDOUT_FRAC * len(pos_list))) if pos_list else 0
    holdout_neg_count = max(1, int(HOLDOUT_FRAC * len(neg_list))) if neg_list else 0

    pos_folds = split_into_folds(pos_list[holdout_pos_count:], N_FOLDS)
    neg_folds = split_into_folds(neg_list[holdout_neg_count:], N_FOLDS)

    holdout_dict[go_term] = labelled_frame(pos_list[:holdout_pos_count], neg_list[:holdout_neg_count])

    fold_dfs = [labelled_frame(p, n) for p, n in zip(pos_folds, neg_folds)]
    cv_fold1_dict[go_term] = fold_dfs[0]
    cv_fold2_dict[go_term] = fold_dfs[1]
    cv_fold3_dict[go_term] = fold_dfs[2]

# Report, per GO term, how many positives (result == 1) and negatives
# (result == 0) landed in the holdout split and in each CV fold.
named_splits = (
    ("Holdout", holdout_dict),
    ("Fold1", cv_fold1_dict),
    ("Fold2", cv_fold2_dict),
    ("Fold3", cv_fold3_dict),
)

for go_term in holdout_dict:
    segments = []
    for split_label, split_dict in named_splits:
        labels = split_dict[go_term]['result']
        n_pos = (labels == 1).sum()
        n_neg = (labels == 0).sum()
        segments.append(f"{split_label}(Pos={n_pos}, Neg={n_neg})")
    print(f"{go_term}: " + " | ".join(segments))


In [None]:
# Persist each split dictionary (GO term -> labelled gene DataFrame) to its own
# pickle file in the working directory, in the order listed.
outputs = {
    "go_cv_fold1_dict_all.pkl": cv_fold1_dict,
    "go_cv_fold2_dict_all.pkl": cv_fold2_dict,
    "go_cv_fold3_dict_all.pkl": cv_fold3_dict,
    "go_holdout_dict_all.pkl": holdout_dict,
}
file_names = list(outputs)
data_dicts = list(outputs.values())

for out_path, payload in outputs.items():
    with open(out_path, 'wb') as handle:
        pickle.dump(payload, handle)

file_names