# In [None]:
import pandas as pd
import numpy as np
import os
import gc
from tqdm.auto import tqdm
from collections import defaultdict

# ==========================================
# 1. SYSTEM CONFIGURATION
# ==========================================
# Root folders: INPUT_DIR holds the source data, OUTPUT_DIR holds the model
# submission files that are read back and the final blended submission.
INPUT_DIR = r'E:\CAFA-6-Protein-Function-Prediction\input'
OUTPUT_DIR = r'E:\CAFA-6-Protein-Function-Prediction\output'

# Data paths: the GO ontology (OBO format) and the GOA annotation table (CSV).
GO_OBO_PATH = f'{INPUT_DIR}/cafa-6-protein-function-prediction/Train/go-basic.obo'
GOA_PATH = f'{INPUT_DIR}/protein-go-annotations/goa_uniprot_all.csv'

# Blend weights, one per model. main() passes each weight to
# load_submission_weighted(), which multiplies that model's scores by it
# before the per-(id, term) sum.
WEIGHTS = {
    'mlp': 0.15,
    'hybrid': 0.35,
    'resnet': 0.25,
    'knn': 0.25
}

print(f">>> Config Loaded.")
print(f"Weights: {WEIGHTS}")

# ==========================================
# 2. OBO PARSING HELPERS (unchanged)
# ==========================================
def parse_obo(go_obo_path):
    """Build a parent -> children map from an ontology file in OBO format.

    Only ``[Term]`` stanzas are read. Inside a term, every ``is_a:`` line and
    every ``relationship: part_of`` line registers the current term as a
    direct child of the referenced parent. Lines belonging to any other
    stanza (for example ``[Typedef]``) or to the file header are ignored, so
    relation definitions never end up in the map.

    Args:
        go_obo_path: Path to the ``.obo`` file.

    Returns:
        A ``defaultdict(set)`` mapping each parent ID to the set of IDs of
        its direct children. The map is empty when the file does not exist
        (a warning is printed in that case).
    """
    print(f"[OBO] Parsing {go_obo_path}...")
    children = defaultdict(set)
    if not os.path.exists(go_obo_path):
        print(f"⚠️ Warning: OBO file not found at {go_obo_path}")
        return children

    # Decode as UTF-8 explicitly instead of relying on the platform default
    # (e.g. cp1252 on Windows), which can fail on non-ASCII bytes.
    with open(go_obo_path, "r", encoding="utf-8") as f:
        cur_id = None
        in_term = False
        for line in f:
            line = line.strip()

            # Any stanza header ends the previous stanza; only [Term] opens
            # one whose tags we care about.
            if line.startswith("[") and line.endswith("]"):
                in_term = line == "[Term]"
                cur_id = None
                continue
            if not in_term:
                continue

            if line.startswith("id: "):
                cur_id = line[len("id: "):].strip()
            elif line.startswith("is_a: "):
                # "is_a: GO:0000001 ! name" -> second token is the parent ID.
                pid = line.split()[1].strip()
                if cur_id:
                    children[pid].add(cur_id)
            elif line.startswith("relationship: part_of "):
                # "relationship: part_of GO:0000001 ! name" -> third token.
                parts = line.split()
                if len(parts) >= 3:
                    pid = parts[2].strip()
                    if cur_id:
                        children[pid].add(cur_id)
    print(f"[OBO] Parsed {len(children)} parent nodes.")
    return children

# Parse the ontology once when the module is loaded; get_descendants() reads
# this module-level parent -> children map.
children_map = parse_obo(GO_OBO_PATH)

def get_descendants(go_id):
    """Collect every transitive descendant of a GO term.

    Walks the module-level ``children_map`` (parent -> direct children)
    starting at ``go_id``.

    Args:
        go_id: ID of the term whose descendants are wanted.

    Returns:
        A set with the IDs of all terms reachable from ``go_id`` through
        ``children_map``. ``go_id`` itself is only part of the result if it
        is reachable from one of its own descendants.
    """
    found = set()
    pending = [go_id]
    while pending:
        node = pending.pop()
        # Queue only children not seen before, so the walk terminates even
        # if the map contains a cycle.
        fresh = [kid for kid in children_map.get(node, []) if kid not in found]
        found.update(fresh)
        pending.extend(fresh)
    return found

# ==========================================
# 3. NEGATIVE ANNOTATION HANDLING (NEW OPTIMIZATION)
# ==========================================
def get_negative_keys_optimized():
    """Build the blacklist of (protein, GO term) pairs negated in GOA.

    The GOA CSV at ``GOA_PATH`` is read in chunks to keep memory bounded.
    Rows whose ``qualifier`` contains ``NOT`` are kept, and every negated
    term is then extended with all of its descendants (a negated parent
    implies negated children) using ``get_descendants``.

    Returns:
        A DataFrame with columns ``id`` and ``term``, one row per blocked
        pair, or ``None`` when the GOA file is missing or contains no
        ``NOT`` annotation.
    """
    if not os.path.exists(GOA_PATH):
        print("⚠️ GOA file not found. Skipping Negative Propagation.")
        return None

    print("[NEG] Loading GOA annotations efficiently (Chunking)...")

    neg_pairs = []
    # Read blocks of 1 million rows to avoid exhausting RAM.
    chunk_size = 1_000_000

    reader = pd.read_csv(
        GOA_PATH,
        usecols=['protein_id', 'go_term', 'qualifier'],
        # Force a string column: a chunk whose qualifier cells are all empty
        # would otherwise be parsed as float64 and break the `.str` accessor.
        dtype={'qualifier': str},
        chunksize=chunk_size,
    )

    # tqdm shows reading progress, one tick per chunk.
    for chunk in tqdm(reader, desc="Reading GOA chunks"):
        # Keep only the 'NOT' rows right away so the chunk can be discarded.
        filtered = chunk[chunk['qualifier'].str.contains('NOT', na=False)]
        if not filtered.empty:
            neg_pairs.append(filtered[['protein_id', 'go_term']])

    if not neg_pairs:
        return None

    # Merge the per-chunk results.
    neg_df = pd.concat(neg_pairs)
    print(f"[NEG] Found {len(neg_df)} direct negative annotations. Propagating...")

    # Group by protein so descendants can be propagated per protein.
    neg_map = neg_df.groupby('protein_id')['go_term'].apply(list).to_dict()

    # The same term is negated for many proteins; compute its descendants once.
    descendants_cache = {}
    final_neg_list = []

    # Propagation rule: parent is NOT -> all descendants are NOT as well.
    for pid, terms in tqdm(neg_map.items(), desc="Propagating Negatives"):
        all_neg_terms = set(terms)
        for t in terms:
            if t not in descendants_cache:
                descendants_cache[t] = get_descendants(t)
            all_neg_terms |= descendants_cache[t]

        final_neg_list.extend((pid, t) for t in all_neg_terms)

    # DataFrame holding the blacklist.
    blocked_df = pd.DataFrame(final_neg_list, columns=['id', 'term'])

    # Free intermediate structures before returning.
    del neg_df, neg_map, neg_pairs, descendants_cache, final_neg_list
    gc.collect()

    print(f"[NEG] Final blocked pairs count: {len(blocked_df)}")
    return blocked_df

# ==========================================
# 4. LOADING HELPERS & MAIN
# ==========================================
def load_submission_weighted(filename, weight):
    """Read one model's submission file and scale its scores by ``weight``.

    Args:
        filename: Name of the submission file, looked up inside ``OUTPUT_DIR``.
        weight: Multiplier applied to the ``score`` column.

    Returns:
        A DataFrame with columns ``id``, ``term`` and ``score`` (``score`` is
        read as float32 and already multiplied by ``weight``), or ``None``
        when the file does not exist.
    """
    full_path = os.path.join(OUTPUT_DIR, filename)

    if os.path.exists(full_path):
        print(f"   + Loading {filename} (w={weight})...")
        # The file is read as tab-separated with no header row.
        frame = pd.read_csv(
            full_path,
            sep='\t',
            header=None,
            names=['id', 'term', 'score'],
            dtype={'score': np.float32},
        )
        frame['score'] *= weight
        return frame

    print(f"❌ Missing: {filename} -> Skipping!")
    return None

def main():
    """Blend the model submissions, filter them and write ``submission.tsv``.

    Steps:
        1. Build the blacklist of negated (protein, term) pairs.
        2. For every model in ``WEIGHTS``, load
           ``OUTPUT_DIR/submission_<model>.tsv`` (missing files are skipped)
           and sum the weighted scores per (id, term).
        3. Drop blended predictions that appear in the blacklist.
        4. If ``submission_exact.tsv`` exists and is non-empty, raise each
           pair's score to at least its exact-match score; pairs present
           only in that file are added.
        5. Round scores to 3 decimals, keep those strictly above 0.001 and
           write a header-less, tab-separated file.

    Raises:
        ValueError: If none of the model submission files exists.
    """
    # --- STEP 1: PREPARE THE FILTER DATA ---
    blocked_df = get_negative_keys_optimized()

    # --- STEP 2: LOAD & BLEND ---
    print("\n>>> [Step 2] Blending Models...")
    dfs = []
    # Driven by WEIGHTS so that adding a model there is enough to blend it.
    for model_name, weight in WEIGHTS.items():
        df_model = load_submission_weighted(f'submission_{model_name}.tsv', weight)
        if df_model is not None:
            dfs.append(df_model)

    if not dfs:
        raise ValueError("CRITICAL: No submission files found!")

    print("   Concatenating & Summing...")
    full_df = pd.concat(dfs, ignore_index=True)

    del dfs, df_model
    gc.collect()

    final_df = full_df.groupby(['id', 'term'], as_index=False)['score'].sum()

    # The concatenated frame is no longer needed once it has been aggregated.
    del full_df
    gc.collect()
    print(f"   Blended Shape: {final_df.shape}")

    # --- STEP 3: NEGATIVE FILTER (ANTI-JOIN VIA MERGE) ---
    if blocked_df is not None and not blocked_df.empty:
        print("\n>>> [Step 3] Applying Negative Filtering (Optimized)...")
        initial_len = len(final_df)

        # Left join with an indicator column: rows marked 'left_only' have no
        # match in the blacklist and are the ones to keep.
        merged = final_df.merge(
            blocked_df[['id', 'term']],
            on=['id', 'term'],
            how='left',
            indicator=True,
        )
        final_df = merged[merged['_merge'] == 'left_only'].drop(columns=['_merge'])

        # Free memory.
        del merged, blocked_df
        gc.collect()

        print(f"   Removed {initial_len - len(final_df)} invalid predictions.")

    # --- STEP 4: POSITIVE OVERRIDE (EXACT MATCH) ---
    exact_path = os.path.join(OUTPUT_DIR, 'submission_exact.tsv')
    if os.path.exists(exact_path) and os.path.getsize(exact_path) > 0:
        print("\n>>> [Step 4] Applying Exact Match Override...")
        df_exact = pd.read_csv(exact_path, sep='\t', names=['id', 'term', 'score_exact'], header=None)

        # Outer join so pairs found only in the exact-match file are kept;
        # a missing score on either side counts as 0.
        final_df = final_df.merge(df_exact, on=['id', 'term'], how='outer')
        final_df['score'] = final_df['score'].fillna(0)
        final_df['score_exact'] = final_df['score_exact'].fillna(0)

        final_df['score'] = np.maximum(final_df['score'], final_df['score_exact'])
        final_df = final_df[['id', 'term', 'score']]
        print("   Override applied.")
    else:
        print("\n>>> [Step 4] No Exact Match file found. Skipping override.")

    # --- STEP 5: SAVE THE FILE ---
    print("\n>>> [Step 5] Saving Final Submission...")
    # assign() returns a new frame, so no slice of an earlier frame is mutated.
    final_df = final_df.assign(score=final_df['score'].round(3))
    final_df = final_df[final_df['score'] > 0.001]

    final_output = os.path.join(OUTPUT_DIR, 'submission.tsv')
    final_df.to_csv(final_output, sep='\t', index=False, header=False)

    print(f"✅ SUCCESS! Generated {final_output} with {len(final_df)} rows.")

# Run the blending pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()