In [None]:
# =========================================================
# IMPORTS
# =========================================================
import pandas as pd
import numpy as np
import pickle
from tqdm import tqdm
import os
import gc
from collections import defaultdict

# =========================================================
# CONFIGURATION
# =========================================================
# Settings for the offline post-processing ("repair") pipeline below.
CONFIG_POST = {
    # --- INPUT FILES ---
    # Raw predictions, read as a header-less TSV with columns PID, Term, Score.
    'RAW_FILE': "submission_c95_c99_final.tsv",
    # CSV with a 'term' column; only terms listed here are kept.
    'VOCAB_FILE': "/kaggle/input/c99-cafa6/vocab_C99_remove.csv",
    # Optional KNN predictions (header-less TSV: PID, Term, Score) used to
    # rescue terms missing from RAW_FILE. The pipeline reads this key, so it
    # must exist; an empty string disables the rescue step.
    # NOTE(review): the real path was not present in this file - fill it in.
    'KNN_FILE': "",

    # --- EXTERNAL DATA ---
    # CSV with columns protein_id, go_term, qualifier.
    'GOA_FILE': "/kaggle/input/protein-go-annotations/goa_uniprot_all.csv",
    'OBO_FILE': "/kaggle/input/cafa-6-protein-function-prediction/Train/go-basic.obo",
    # Header-less TSV with columns term, ia.
    'IA_FILE': "/kaggle/input/cafa-6-protein-function-prediction/IA.tsv",

    'OUTPUT_FILE': "submission.tsv",

    # --- PARAMETERS ---
    # Minimum score a prediction needs to be written to OUTPUT_FILE.
    'FINAL_THRESHOLD': 0.35,
    # Maximum number of predictions written per protein.
    'FINAL_CAP': 250,

    # NOTE(review): not read anywhere in the visible code.
    'PROP_MIN_SCORE': 0.50,
}

# =========================================================
# UTILS
# =========================================================
def parse_obo_parents(obo_path):
    """Parse an OBO file into a ``term -> parents`` mapping.

    Only ``[Term]`` stanzas are read; other stanzas (e.g. ``[Typedef]``) also
    carry ``id:`` / ``is_a:`` lines and would otherwise be mistaken for terms.
    Both ``is_a`` and ``relationship: part_of`` lines count as parent edges.

    Args:
        obo_path: Path to the OBO file.

    Returns:
        A ``defaultdict(set)`` mapping each term id to the set of its direct
        parent ids. Empty when ``obo_path`` does not exist.
    """
    print(f"Parsing OBO from {obo_path}...")
    term_to_parents = defaultdict(set)
    if not os.path.exists(obo_path):
        return term_to_parents

    with open(obo_path, "r", encoding="utf-8") as f:
        in_term = False
        cur_id = None
        for line in f:
            line = line.strip()
            if line.startswith("[") and line.endswith("]"):
                # Any stanza header starts a new record; only [Term] is parsed.
                in_term = line == "[Term]"
                cur_id = None
            elif not in_term:
                continue
            elif line.startswith("id: "):
                cur_id = line[len("id: "):].strip()
            elif cur_id and line.startswith("is_a: "):
                # "is_a: GO:0000001 ! name" -> second token is the parent id.
                term_to_parents[cur_id].add(line.split()[1])
            elif cur_id and line.startswith("relationship: part_of "):
                # "relationship: part_of GO:0000001 ! name" -> third token.
                term_to_parents[cur_id].add(line.split()[2])
    return term_to_parents

def get_descendants(term, children_map):
    """Collect every term reachable from ``term`` by following child edges.

    Args:
        term: Starting term id.
        children_map: Mapping from a term id to an iterable of its direct
            children. Terms absent from the mapping have no children.

    Returns:
        The set of all descendants. ``term`` itself is included only if it is
        reachable from one of its own descendants.
    """
    found = set()
    pending = [term]
    while pending:
        node = pending.pop()
        # Only children not seen before need to be recorded and expanded.
        unseen = set(children_map.get(node, ())) - found
        found |= unseen
        pending.extend(unseen)
    return found

def parse_obo_children(obo_path):
    """Parse an OBO file into a ``parent -> children`` mapping.

    Only ``[Term]`` stanzas are read; other stanzas (e.g. ``[Typedef]``) also
    carry ``id:`` / ``is_a:`` lines and would otherwise add bogus parents.
    Both ``is_a`` and ``relationship: part_of`` lines count as parent edges.

    Args:
        obo_path: Path to the OBO file.

    Returns:
        A ``defaultdict(set)`` mapping each parent id to the set of its direct
        child ids.

    Raises:
        OSError: If ``obo_path`` cannot be opened.
    """
    print(f"Parsing OBO from {obo_path}...")
    children_map = defaultdict(set)

    with open(obo_path, "r", encoding="utf-8") as f:
        in_term = False
        cur_id = None
        for line in f:
            line = line.strip()
            if line.startswith("[") and line.endswith("]"):
                # Any stanza header starts a new record; only [Term] is parsed.
                in_term = line == "[Term]"
                cur_id = None
            elif not in_term:
                continue
            elif line.startswith("id: "):
                cur_id = line[len("id: "):].strip()
            elif cur_id and line.startswith("is_a: "):
                # "is_a: GO:0000001 ! name" -> second token is the parent id.
                children_map[line.split()[1]].add(cur_id)
            elif cur_id and line.startswith("relationship: part_of "):
                # "relationship: part_of GO:0000001 ! name" -> third token.
                children_map[line.split()[2]].add(cur_id)

    print(f"Parsed {len(children_map)} parents with children.")
    return children_map


def load_uniprot_data(goa_path, children_map):
    """Load GOA annotations and split them into positive and negative sets.

    A row is negative when its ``qualifier`` contains the substring ``NOT``
    (missing qualifiers count as positive). Each negative term is expanded to
    all of its descendants via ``children_map``.

    Args:
        goa_path: CSV file with ``protein_id``, ``go_term`` and ``qualifier``
            columns (all read as strings).
        children_map: Mapping from a term id to its direct children.

    Returns:
        A tuple ``(pos_map, neg_keys)``: ``pos_map`` is a ``defaultdict(set)``
        from protein id to its positively annotated terms, and ``neg_keys`` is
        a set of ``"<protein_id>_<go_term>"`` strings for negated pairs.
    """
    print(f"\n>>> Loading GOA Database from {goa_path}...")
    annotations = pd.read_csv(
        goa_path,
        usecols=['protein_id', 'go_term', 'qualifier'],
        dtype=str,
    )
    print(f"Loaded {len(annotations)} annotations.")

    # One mask drives both the negative and the positive selection.
    is_negative = annotations['qualifier'].str.contains('NOT', na=False)

    negatives = annotations[is_negative]
    neg_map = negatives.groupby('protein_id')['go_term'].apply(list).to_dict()

    neg_keys = set()
    print("Propagating Negatives...")
    for protein, negated_terms in tqdm(neg_map.items()):
        for go_id in negated_terms:
            neg_keys.add(f"{protein}_{go_id}")
            # A negated term also negates everything below it.
            neg_keys.update(
                f"{protein}_{child}"
                for child in get_descendants(go_id, children_map)
            )

    print(f"Identified {len(neg_keys)} negative Protein-Term pairs.")

    positives = annotations[~is_negative]
    pos_map = defaultdict(set)
    for protein, go_id in zip(positives['protein_id'], positives['go_term']):
        pos_map[protein].add(go_id)

    print(f"Identified {sum(len(v) for v in pos_map.values())} positive pairs to inject (Score = 1.0).")

    # Drop the large frames before returning to keep peak memory down.
    del annotations, negatives, positives, is_negative
    gc.collect()
    return pos_map, neg_keys
# =========================================================
# CORE PROCESSING
# =========================================================
def run_offline_repair_safe():
    """Post-process raw predictions and write the final submission file.

    Steps, all driven by ``CONFIG_POST``:

    1. Load the vocabulary, IA weights, OBO hierarchy, GOA annotations and
       (optionally) KNN predictions.
    2. Stream ``RAW_FILE`` in chunks. A change of PID flushes the previous
       protein, so rows are expected to be grouped by PID; a protein whose
       rows are split apart would be written in several pieces.
    3. For each protein: add KNN-only terms that pass the IA/score gates,
       set GOA positives to 1.0, drop GOA negatives, then keep scores
       ``>= FINAL_THRESHOLD``, capped at ``FINAL_CAP`` best rows.
    4. Process proteins that appear only in the KNN or GOA data.

    Rows are written to ``OUTPUT_FILE`` as ``PID<TAB>term<TAB>score`` with
    three decimals. Summary statistics are printed at the end.
    """
    print(" STARTING SAFE OFFLINE REPAIR PIPELINE...")

    # --- A. LOAD RESOURCES ---
    print(" Loading Vocab & IA...")
    vocab_df = pd.read_csv(CONFIG_POST['VOCAB_FILE'])
    term_to_idx = {t: i for i, t in enumerate(vocab_df['term'])}
    idx_to_term = {i: t for i, t in enumerate(vocab_df['term'])}

    # IA map, used to gate the KNN rescue below.
    ia_df = pd.read_csv(CONFIG_POST['IA_FILE'], sep='\t', header=None, names=['term', 'ia'])
    term_to_ia = dict(zip(ia_df.term, ia_df.ia))

    # Parent edges restricted to vocabulary terms.
    # NOTE(review): parent_map_idx is built but not used by any later step.
    obo_parents = parse_obo_parents(CONFIG_POST['OBO_FILE'])
    parent_map_idx = defaultdict(list)
    for t_str, parents in obo_parents.items():
        if t_str in term_to_idx:
            c_idx = term_to_idx[t_str]
            for p_str in parents:
                if p_str in term_to_idx:
                    parent_map_idx[c_idx].append(term_to_idx[p_str])

    children_map = parse_obo_children(CONFIG_POST['OBO_FILE'])
    gt_pos_map, gt_neg_keys = load_uniprot_data(CONFIG_POST['GOA_FILE'], children_map)

    print("Loading KNN Rescue Data...")
    knn_rescue_map = defaultdict(list)
    # The KNN file is optional: a missing key or missing file skips the step
    # instead of raising KeyError.
    knn_path = CONFIG_POST.get('KNN_FILE')
    if knn_path and os.path.exists(knn_path):
        knn_df = pd.read_csv(knn_path, sep='\t', names=['PID', 'Term', 'Score'])
        for pid, term, score in tqdm(zip(knn_df.PID, knn_df.Term, knn_df.Score), total=len(knn_df)):
            if term in term_to_idx:
                knn_rescue_map[pid].append((term_to_idx[term], float(score)))
        del knn_df
        gc.collect()
    else:
        print("KNN file not configured or not found; skipping KNN rescue.")

    # Gates for the KNN rescue step.
    knn_min_ia = 4.0         # only rescue terms with very high IA
    knn_min_score = 0.25     # only accept confident KNN scores
    knn_rescued_score = 0.4  # fixed score given to every rescued term

    final_threshold = CONFIG_POST['FINAL_THRESHOLD']
    final_cap = CONFIG_POST['FINAL_CAP']

    # === GLOBAL STATS ===
    STAT_NEG_REMOVED = 0
    STAT_POS_BOOSTED = 0
    STAT_POS_ADDED = 0

    def process_and_write(pid, scores_dict):
        """Repair one protein's scores, write its rows, return the row count.

        ``scores_dict`` maps vocabulary index to score and is modified in
        place. Rows go to ``f_out``, bound by the enclosing ``with`` block.
        """
        nonlocal STAT_NEG_REMOVED, STAT_POS_BOOSTED, STAT_POS_ADDED

        # --- KNN RESCUE ---
        if pid in knn_rescue_map:
            for t_idx, k_score in knn_rescue_map[pid]:
                # Never overwrite a score that came from the raw predictions.
                if t_idx in scores_dict:
                    continue
                if term_to_ia.get(idx_to_term[t_idx], 0.0) < knn_min_ia:
                    continue
                if k_score < knn_min_score:
                    continue
                # Hard clamp: the KNN score itself is not propagated.
                scores_dict[t_idx] = knn_rescued_score

        # --- UNIPROT INJECTION ---
        if pid in gt_pos_map:
            for term in gt_pos_map[pid]:
                if term not in term_to_idx:
                    continue
                t_idx = term_to_idx[term]
                if t_idx in scores_dict:
                    if scores_dict[t_idx] < 1.0:
                        STAT_POS_BOOSTED += 1
                else:
                    STAT_POS_ADDED += 1
                scores_dict[t_idx] = max(scores_dict.get(t_idx, 0.0), 1.0)

        # --- NEGATIVE FILTER ---
        # Runs after injection, so a negated pair is dropped even if it was
        # also injected as a positive.
        keys_to_remove = [
            k for k in scores_dict
            if f"{pid}_{idx_to_term[k]}" in gt_neg_keys
        ]
        STAT_NEG_REMOVED += len(keys_to_remove)
        for k in keys_to_remove:
            del scores_dict[k]

        # --- FINAL FILTER & CAP ---
        final_items = [(t, s) for t, s in scores_dict.items() if s >= final_threshold]
        final_items.sort(key=lambda x: x[1], reverse=True)
        final_items = final_items[:final_cap]

        lines = [f"{pid}\t{idx_to_term[t]}\t{s:.3f}\n" for t, s in final_items]
        f_out.write("".join(lines))
        return len(lines)

    # --- B. PROCESSING ---
    print(" Processing Proteins...")
    total_written = 0
    processed_pids = set()

    # The context manager closes the output file even if processing fails.
    with open(CONFIG_POST['OUTPUT_FILE'], "w") as f_out:
        reader = pd.read_csv(CONFIG_POST['RAW_FILE'], sep='\t', names=['PID', 'Term', 'Score'],
                             chunksize=1000000, dtype={'PID': str, 'Term': str, 'Score': float})

        # State is kept across chunks so a protein spanning a chunk boundary
        # is still processed as one unit.
        current_pid = None
        current_scores = {}

        # --- MAIN LOOP ---
        for chunk in tqdm(reader, desc="Processing Batches"):
            rows = zip(chunk['PID'].values, chunk['Term'].values, chunk['Score'].values)
            for p, t_str, s in rows:
                # Out-of-vocabulary rows are ignored before the PID check.
                if t_str not in term_to_idx:
                    continue

                if p != current_pid:
                    if current_pid is not None:
                        total_written += process_and_write(current_pid, current_scores)
                        processed_pids.add(current_pid)
                    current_pid = p
                    current_scores = {}
                current_scores[term_to_idx[t_str]] = s

        # Flush the last protein of the file.
        if current_pid is not None:
            total_written += process_and_write(current_pid, current_scores)
            processed_pids.add(current_pid)

        # --- RESCUE MISSING ---
        print("ðŸ”„ Rescuing missing proteins...")
        missing_pids = (set(knn_rescue_map.keys()) | set(gt_pos_map.keys())) - processed_pids
        for pid in tqdm(missing_pids):
            total_written += process_and_write(pid, {})

    print("\n>>> Applying Logic...")
    print(f"1. Removing Negatives...")
    print(f"   Removed {STAT_NEG_REMOVED} negative predictions.")

    print("2. Injecting Ground Truth...")
    print(f"   Boosted {STAT_POS_BOOSTED} existing predictions to 1.0")
    print(f"   Adding {STAT_POS_ADDED} completely new ground-truth rows...")
    print(f"\n DONE! Predictions: {total_written:,}")

# Run the full post-processing pipeline only when executed as a script.
if __name__ == "__main__":
    run_offline_repair_safe()

 STARTING SAFE OFFLINE REPAIR PIPELINE...
 Loading Vocab & IA...
Parsing OBO from /kaggle/input/cafa-6-protein-function-prediction/Train/go-basic.obo...
Parsing OBO from /kaggle/input/cafa-6-protein-function-prediction/Train/go-basic.obo...
Parsed 16788 parents with children.

>>> Loading GOA Database from /kaggle/input/protein-go-annotations/goa_uniprot_all.csv...
Loaded 2583077 annotations.
Propagating Negatives...


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 3863/3863 [00:00<00:00, 16451.29it/s]


Identified 238553 negative Protein-Term pairs.
Identified 2464255 positive pairs to inject (Score = 1.0).
Loading KNN Rescue Data...


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 1360373/1360373 [00:02<00:00, 510077.70it/s]


 Processing Proteins...


Processing Batches: 113it [03:52,  2.06s/it]


ðŸ”„ Rescuing missing proteins...


0it [00:00, ?it/s]



>>> Applying Logic...
1. Removing Negatives...
   Removed 6875 negative predictions.
2. Injecting Ground Truth...
   Boosted 1931323 existing predictions to 1.0
   Adding 382835 completely new ground-truth rows...

 DONE! Predictions: 33,521,389
