In [1]:
!pip install obonet networkx --quiet

import obonet
import networkx as nx
import pandas as pd
import numpy as np
import pickle
from collections import defaultdict, deque
from tqdm import tqdm

# ============================================================================
# 1. Input / output paths (double-check before running)
# ============================================================================
CONFIG = dict(
    # Source ontology in OBO format
    OBO_FILE="/kaggle/input/cafa-6-protein-function-prediction/Train/go-basic.obo",
    # C99 term vocabulary (CSV with a 'term' column) and the pickled training targets
    VOCAB_FILE="/kaggle/input/c99-cafa6/vocab_C99_remove.csv",
    TARGET_FILE="/kaggle/input/c99-cafa6/train_targets_C99.pkl",
    # Output file
    OUTPUT_PKL="hierarchy_metadata.pkl",
)

def _compute_depth_norm(graph, vocab_terms, roots=("GO:0008150", "GO:0003674", "GO:0005575")):
    """Return (depth_arr, depth_norm, max_depth) for the terms in *vocab_terms*.

    depth_arr[i] is the breadth-first distance of vocab_terms[i] from the nearest root in
    *roots*, walking every edge of *graph* from parent to child. Edges in *graph* point
    child -> parent, so a node's predecessors are its children. Terms missing from the
    graph or unreachable from every root get depth 0. depth_norm is depth_arr divided by
    its maximum, or depth_arr itself when that maximum is 0 (or the vocabulary is empty).
    """
    depth = {node: float('inf') for node in graph.nodes()}
    queue = deque()
    for root in roots:
        if root in graph:
            depth[root] = 0
            queue.append(root)

    while queue:
        node = queue.popleft()
        d = depth[node]
        for child in graph.predecessors(node):
            if depth[child] > d + 1:
                depth[child] = d + 1
                queue.append(child)

    depth_arr = np.array([depth.get(t, 0) for t in vocab_terms], dtype=np.float32)
    depth_arr[np.isinf(depth_arr)] = 0.0
    max_depth = depth_arr.max() if depth_arr.size else 0.0
    depth_norm = depth_arr / max_depth if max_depth > 0 else depth_arr
    return depth_arr, depth_norm, max_depth


def _build_ancestor_map(graph, term_to_idx, relations=("is_a",)):
    """Map each vocab index to the vocab indices of ALL its ancestors (transitive closure).

    The closure is computed on the full ontology, not on the vocabulary-restricted graph,
    so a child still reaches a kept ancestor when the intermediate terms were dropped from
    the vocabulary. Only edge types listed in *relations* are followed: "is_a" is read from
    each node's 'is_a' list, any other type from its 'relationship' entries
    ("<type> <GO id>"). Terms with no in-vocabulary ancestor are omitted from the result.
    """
    wanted = set(relations)
    extra_relations = wanted - {"is_a"}
    closure_graph = nx.DiGraph()
    for node, data in graph.nodes(data=True):
        if "is_a" in wanted:
            for p_str in data.get("is_a", []):
                # Edge child -> parent; drop a trailing "! name" comment if present.
                closure_graph.add_edge(node, p_str.split(" ! ")[0])
        if extra_relations:
            for rel_str in data.get("relationship", []):
                parts = rel_str.split()
                if len(parts) >= 2 and parts[0] in extra_relations:
                    closure_graph.add_edge(node, parts[1])

    ancestor_map = {}
    for term, idx in tqdm(term_to_idx.items(), desc="Mapping Transitive"):
        if term not in closure_graph:
            continue
        # In a child -> parent DiGraph, nx.descendants returns every ancestor of `term`.
        ancestors = [
            term_to_idx[anc]
            for anc in nx.descendants(closure_graph, term)
            if anc in term_to_idx
        ]
        if ancestors:
            ancestor_map[idx] = ancestors
    return ancestor_map


def _compute_ic_norm(labels_dict, vocab_size):
    """Return (ic_norm, ic_max): per-term information content scaled to [0, 1].

    IC(t) = -log(n_t / N), where N = len(labels_dict) and n_t is the number of entries
    whose index list contains t. n_t is floored at 1: a term that never occurs would
    otherwise get IC = -log(1e-9 / N), which dominates the maximum and squashes every
    other normalized value toward 0. With no entries (or an empty vocabulary) all-zero
    IC is returned.
    Assumes each value of labels_dict is a sequence of integer vocab indices
    (TODO confirm against the producer of TARGET_FILE).
    """
    term_counts = np.zeros(vocab_size, dtype=np.float32)
    total_samples = len(labels_dict)
    if total_samples == 0 or vocab_size == 0:
        return np.zeros(vocab_size, dtype=np.float32), 0.0

    for indices in labels_dict.values():
        if len(indices) > 0:
            # Fancy-index += increments each distinct index at most once per entry.
            term_counts[indices] += 1

    freq = np.maximum(term_counts, 1.0) / total_samples
    ic_values = -np.log(freq)
    ic_max = ic_values.max()
    ic_norm = ic_values / ic_max if ic_max > 0 else np.zeros_like(ic_values)
    return ic_norm, ic_max


def build_metadata_v2_final(relations=("is_a",)):
    """Build GO hierarchy metadata for the vocabulary and pickle it to CONFIG['OUTPUT_PKL'].

    Keys of the saved dict:
        depth_norm: float32 array in vocab order, BFS depth from the GO roots scaled to [0, 1].
        ic_norm: float32 array in vocab order, -log(label frequency) over
            CONFIG['TARGET_FILE'], scaled to [0, 1].
        child_to_parent: {term_idx: [ancestor_idx, ...]}, transitive ancestors restricted
            to the vocabulary.
        term_to_idx: {term: idx} built from CONFIG['VOCAB_FILE'].

    Args:
        relations: edge types followed when building child_to_parent. The default
            ("is_a",) keeps the original behaviour; ("is_a", "part_of") additionally
            propagates across part_of links.
    """
    print("üöÄ B·∫ÆT ƒê·∫¶U T·∫†O METADATA (FULL & CORRECTED)...")

    # --- A. Load ontology and vocabulary ---
    print("\n[1/4] Loading Graph & Vocab...")
    graph = obonet.read_obo(CONFIG['OBO_FILE'])
    vocab_df = pd.read_csv(CONFIG['VOCAB_FILE'])
    vocab_terms = vocab_df['term'].tolist()
    vocab_size = len(vocab_terms)
    term_to_idx = {t: i for i, t in enumerate(vocab_terms)}

    print(f"   - Graph nodes: {len(graph):,}")
    print(f"   - Vocab size: {vocab_size:,}")

    # --- B. Normalized depth (BFS from the roots) ---
    print("\n[2/4] Calculating Depth Norm...")
    _, depth_norm, max_depth = _compute_depth_norm(graph, vocab_terms)
    print(f"   - Max Depth: {max_depth}")

    # --- C. Transitive ancestor map ---
    print("\n[3/4] Building Parent Map (Transitive - Fix Broken Chains)...")
    print("   - Building full is_a graph...")
    child_to_parent = _build_ancestor_map(graph, term_to_idx, relations)
    count_edges = sum(len(parents) for parents in child_to_parent.values())
    print(f"   - Mapped {count_edges:,} transitive relationships (Bridged gaps).")

    # --- D. Information content ---
    print("\n[4/4] Calculating Information Content (IC)...")
    # pickle.load can execute arbitrary code: only point TARGET_FILE at trusted files.
    with open(CONFIG['TARGET_FILE'], 'rb') as f:
        labels_dict = pickle.load(f)
    ic_norm, ic_max = _compute_ic_norm(labels_dict, vocab_size)

    print(f"   - IC Max Value: {ic_max:.4f}")
    print(f"   - IC Norm Shape: {ic_norm.shape}")

    # --- E. Save ---
    print(f"\n>>> Saving to {CONFIG['OUTPUT_PKL']}...")
    save_data = {
        'depth_norm': depth_norm,
        'ic_norm': ic_norm,
        'child_to_parent': child_to_parent,
        'term_to_idx': term_to_idx,
    }
    with open(CONFIG['OUTPUT_PKL'], 'wb') as f:
        pickle.dump(save_data, f)

    print("‚úÖ DONE! Metadata V2 (Final) Created Successfully.")


if __name__ == "__main__":
    build_metadata_v2_final()

🚀 BẮT ĐẦU TẠO METADATA (FULL & CORRECTED)...

[1/4] Loading Graph & Vocab...
   - Graph nodes: 40,122
   - Vocab size: 15,582

[2/4] Calculating Depth Norm...
   - Max Depth: 11.0

[3/4] Building Parent Map (Transitive - Fix Broken Chains)...
   - Building full is_a graph...


Mapping Transitive: 100%|██████████| 15582/15582 [00:00<00:00, 29222.64it/s]


   - Mapped 130,171 transitive relationships (Bridged gaps).

[4/4] Calculating Information Content (IC)...
   - IC Max Value: 9.3735
   - IC Norm Shape: (15582,)

>>> Saving to hierarchy_metadata.pkl...
✅ DONE! Metadata V2 (Final) Created Successfully.


In [2]:
# import pandas as pd
# import numpy as np
# import matplotlib.pyplot as plt

# # --- N√äN ƒê·ªîI T√äN FILE CHO ƒê√öNG ---
# # base_path   = "/kaggle/input/sample-submission/0.275_submission.tsv"
# # boost350    = "/kaggle/input/sample-submission/0.229_submission.tsv"
# # boost50     = "/kaggle/input/sample-submission/0.233_submission.tsv"
# # low     = "/kaggle/input/sample-submission/0.266_submission.tsv"

# submit = "/kaggle/working/submission.tsv"

# def load_sub(path):
#     return pd.read_csv(path, sep='\t', names=['protein','term','score'])

In [3]:
# import pandas as pd
# import numpy as np
# import matplotlib.pyplot as plt
# import seaborn as sns

# # File b·∫°n v·ª´a t·∫°o
# SUBMISSION_FILE = "/kaggle/input/sample-submission/0.275_submission.tsv"

# def analyze_threshold_impact():
#     print(f"üìÇ ƒêang ƒë·ªçc file {SUBMISSION_FILE}...")
#     # ƒê·ªçc file (c·ªôt: Protein, Term, Score)
#     df = pd.read_csv(SUBMISSION_FILE, sep='\t', names=['Protein', 'Term', 'Score'])
    
#     total_rows = len(df)
#     total_proteins = df['Protein'].nunique()
    
#     print(f"‚úÖ ƒê√£ t·∫£i xong!")
#     print(f"   - T·ªïng d√≤ng hi·ªán t·∫°i: {total_rows:,}")
#     print(f"   - T·ªïng Protein: {total_proteins:,}")
#     print(f"   - Trung b√¨nh nh√£n/Protein (Hi·ªán t·∫°i): {total_rows / total_proteins:.1f}")
    
#     # ‚úÖ TH·ªêNG K√ä MIN / MAX LABELS PER PROTEIN
#     labels_per_protein = df.groupby('Protein').size()
#     min_labels = labels_per_protein.min()
#     max_labels = labels_per_protein.max()
    
#     print(f"\nüìå TH·ªêNG K√ä LABELS / PROTEIN:")
#     print(f"   - Min labels / protein : {min_labels}")
#     print(f"   - Max labels / protein : {max_labels}")
    
#     # --- TH·ªêNG K√ä NG∆Ø·ª†NG GI·∫¢ L·∫¨P ---
#     print("\nüìä N·∫æU TƒÇNG THRESHOLD TH√å SAO?")
#     print("-" * 50)
#     print(f"{'Threshold':<10} | {'S·ªë d√≤ng c√≤n l·∫°i':<15} | {'Nh√£n/Protein':<15} | {'% Gi·ªØ l·∫°i':<10}")
#     print("-" * 50)
    
#     # C√°c m·ªëc threshold mu·ªën th·ª≠
#     thresholds = [0.01, 0.05, 0.08, 0.10, 0.12, 0.15, 0.35, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    
#     for thr in thresholds:
#         # L·ªçc gi·∫£ l·∫≠p
#         filtered_count = (df['Score'] >= thr).sum()
#         avg_labels = filtered_count / total_proteins
#         percent_kept = (filtered_count / total_rows) * 100
        
#         print(f"{thr:<10.2f} | {filtered_count:<15,} | {avg_labels:<15.1f} | {percent_kept:<9.1f}%")
        
#     print("-" * 50)
    
#     # --- V·∫º BI·ªÇU ƒê·ªí PH√ÇN B·ªê SCORE ---
#     plt.figure(figsize=(10, 5))
#     sns.histplot(df['Score'], bins=100, color='teal', kde=False)
#     plt.axvline(0.08, color='red', linestyle='--', label='Current Threshold (0.08)')
#     plt.title(f'Ph√¢n b·ªë ƒëi·ªÉm s·ªë trong file {SUBMISSION_FILE}')
#     plt.xlabel('Score')
#     plt.ylabel('S·ªë l∆∞·ª£ng')
#     plt.legend()
#     plt.yscale('log') # D√πng log scale ƒë·ªÉ nh√¨n r√µ v√πng th·∫•p
#     plt.show()

# analyze_threshold_impact()

In [4]:
# import pandas as pd
# import numpy as np

# # ========= FILE PATH =========
# submission_path = "submission.tsv"
# ia_path = "/kaggle/input/cafa-6-protein-function-prediction/IA.tsv"

# # ========= ROOT TERMS (B·ªé KH·ªéI TH·ªêNG K√ä) =========
# ROOT_TERMS = {"GO:0003674", "GO:0005575", "GO:0008150"}

# # ========= LOAD IA =========
# df_ia = pd.read_csv(
#     ia_path,
#     sep="\t",
#     header=None,
#     names=["go_term", "ia"]
# )
# ia_map = dict(zip(df_ia["go_term"], df_ia["ia"]))
# print("‚úÖ S·ªë GO term c√≥ IA:", len(ia_map))

# # ========= IA BUCKET =========
# ia_labels = ["IA = 0", "0 < IA < 1", "1 ‚Üí 2", "2 ‚Üí 4", "‚â• 4"]

# # ========= SCORE BUCKET =========
# score_bins = [(i/10, (i+1)/10) for i in range(0, 10)]
# score_bins.append(("==1.0", "==1.0"))  # ‚úÖ bucket ƒë·∫∑c bi·ªát

# results = {b: {k: 0 for k in ia_labels} for b in score_bins}
# totals = {b: 0 for b in score_bins}

# missing_ia = 0
# skipped_root = 0

# # ========= STREAM SUBMISSION =========
# for chunk in pd.read_csv(
#     submission_path,
#     sep="\t",
#     header=None,
#     names=["protein_id", "go_term", "score"],
#     chunksize=2_000_000
# ):
#     for row in chunk.itertuples(index=False):
#         pid, go, sc = row

#         # ‚úÖ B·ªé ROOT
#         if go in ROOT_TERMS:
#             skipped_root += 1
#             continue

#         ia = ia_map.get(go, None)
#         if ia is None:
#             missing_ia += 1
#             continue

#         # ========= X√ÅC ƒê·ªäNH SCORE BUCKET =========
#         if sc == 1.0:
#             sb = ("==1.0", "==1.0")
#         else:
#             sb = None
#             for (s_lo, s_hi) in score_bins[:-1]:
#                 if s_lo <= sc < s_hi:
#                     sb = (s_lo, s_hi)
#                     break

#             if sb is None:
#                 continue

#         # ========= IA BUCKET LOGIC =========
#         if ia == 0:
#             results[sb]["IA = 0"] += 1
#         elif 0 < ia < 1:
#             results[sb]["0 < IA < 1"] += 1
#         elif 1 <= ia < 2:
#             results[sb]["1 ‚Üí 2"] += 1
#         elif 2 <= ia < 4:
#             results[sb]["2 ‚Üí 4"] += 1
#         else:
#             results[sb]["‚â• 4"] += 1

#         totals[sb] += 1

# # ========= IN K·∫æT QU·∫¢ =========
# for b in score_bins:
#     if b == ("==1.0", "==1.0"):
#         print(f"\nüî• PH√ÇN B·ªê IA (score == 1.0) [ƒê√É B·ªé ROOT]:")
#     else:
#         print(f"\nüìä PH√ÇN B·ªê IA ({b[0]:.1f} ‚Üí {b[1]:.1f}) [ƒê√É B·ªé ROOT]:")

#     total = totals[b]
#     if total == 0:
#         print("  (Tr·ªëng)")
#         continue

#     for k in ia_labels:
#         cnt = results[b][k]
#         pct = cnt / total * 100
#         print(f"{k:6s}: {cnt:>10,}  |  {pct:6.2f} %")

# print("\n‚ö†Ô∏è S·ªë d√≤ng b·ªã thi·∫øu IA:", missing_ia)
# print("üö´ S·ªë d√≤ng b·ªã lo·∫°i v√¨ l√† ROOT:", skipped_root)


In [5]:
# import pandas as pd
# import numpy as np
# from tqdm import tqdm
# import os
# from collections import defaultdict

# # =========================================================
# # C·∫§U H√åNH & THAM S·ªê
# # =========================================================
# SUB_FILE = "submission.tsv"
# IA_FILE  = "/kaggle/input/cafa-6-protein-function-prediction/IA.tsv"
# CHUNK_SIZE = 5_000_000 
# ROOT_TERMS = ["GO:0003674", "GO:0005575", "GO:0008150"]
# SCORE_BINS = np.linspace(0.0, 1.0, 11) # T·∫°o 11 m·ªëc: 0.0, 0.1, 0.2, ..., 1.0

# def analyze_ia_extremes_fixed():
#     # 1. Load submission (Ph√°t hi·ªán c·ªôt)
#     print("1. Reading submission to identify columns...")
#     df_head = pd.read_csv(SUB_FILE, sep=None, engine="python", header=None, nrows=1000)

#     go_col, score_col = None, None
#     for col in df_head.columns:
#         sample = df_head[col].astype(str)
#         if sample.str.startswith("GO:").any(): go_col = col
#         if pd.to_numeric(sample, errors="coerce").notna().mean() > 0.9: score_col = col

#     if go_col is None or score_col is None:
#         raise ValueError("Kh√¥ng t√¨m th·∫•y c·ªôt GO-term ho·∫∑c Score. Ki·ªÉm tra l·∫°i SUB_FILE.")

#     # 2. Load IA Map
#     ia_df = pd.read_csv(IA_FILE, sep="\t", header=None, names=["go", "ia"])
#     ia_map = dict(zip(ia_df.go, ia_df.ia))
#     print(f"2. IA Map Loaded ({len(ia_map):,} terms).")

#     # 3. PH√ÇN T√çCH STREAMING
#     ia_bin_stats = defaultdict(lambda: {'total_ia': 0.0, 'count': 0})

#     print("\n3. Starting Streaming Analysis (Binning Mean IA)...")
#     reader = pd.read_csv(SUB_FILE, sep=None, engine="python", header=None, chunksize=CHUNK_SIZE)

#     for chunk in tqdm(reader, desc="Processing Chunks"):
#         # ƒê·ªïi t√™n c·ªôt
#         df_chunk = chunk.rename(columns={go_col: "go", score_col: "score"})
        
#         # L·ªçc 3 nh√£n g·ªëc
#         df_chunk = df_chunk[~df_chunk['go'].isin(ROOT_TERMS)].copy()
        
#         # G√°n IA (Fill NaN = 0 cho an to√†n)
#         df_chunk["ia"] = df_chunk["go"].map(ia_map).fillna(0.0)
        
#         # L·ªçc ƒëi·ªÉm s·ªë h·ª£p l·ªá
#         df_chunk = df_chunk[(df_chunk['score'] >= 0) & (df_chunk['score'] <= 1.0)].copy()

#         # T√≠nh to√°n BINS
#         df_chunk['bin_index'] = np.digitize(df_chunk['score'], bins=SCORE_BINS) - 1 # ƒê·ªïi v·ªÅ index 0-9
        
#         # Nh√≥m theo index bin v√† t√≠nh t·ªïng IA v√† Count
#         chunk_stats = df_chunk.groupby('bin_index').agg(
#             total_ia=('ia', 'sum'),
#             count=('ia', 'count')
#         ).reset_index()

#         # C·∫≠p nh·∫≠t global stats
#         for _, row in chunk_stats.iterrows():
#             # [FIX L·ªñI] √âp ki·ªÉu v·ªÅ int ti√™u chu·∫©n
#             bin_idx = int(row['bin_index']) 
            
#             # Ki·ªÉm tra l·ªói ngo√†i bi√™n (do np.digitize ƒë√¥i khi tr·∫£ v·ªÅ 10)
#             if bin_idx >= len(SCORE_BINS) - 1 or bin_idx < 0:
#                 continue 
            
#             bin_min = SCORE_BINS[bin_idx]
#             bin_max = SCORE_BINS[bin_idx + 1]
            
#             key = (bin_min, bin_max)
#             ia_bin_stats[key]['total_ia'] += row['total_ia']
#             ia_bin_stats[key]['count'] += row['count']


#     # --- K·∫æT QU·∫¢ CU·ªêI C√ôNG ---
#     print("\n" + "="*70)
#     print(f"üìä PH√ÇN T√çCH MEAN IA THEO KHO·∫¢NG ƒêI·ªÇM S·ªê")
#     print("="*70)
    
#     total_all_counts = sum(s['count'] for s in ia_bin_stats.values())
    
#     print(f"{'Score Range':<15} | {'Mean IA':<10} | {'Total Count':<15} | {'Density (%)':<10}")
#     print("-" * 70)

#     for (bin_min, bin_max), stats in sorted(ia_bin_stats.items()):
#         count = stats['count']
#         mean_ia = stats['total_ia'] / count if count > 0 else 0.0
#         density = (count / total_all_counts) * 100 if total_all_counts > 0 else 0.0
        
#         range_str = f"[{bin_min:.2f} - {bin_max:.2f})"
        
#         print(f"{range_str:<15} | {mean_ia:<10.4f} | {count:<15,} | {density:<9.1f}%")

#     print("="*70)

# if __name__ == "__main__":
#     analyze_ia_extremes_fixed()

In [None]:
import os
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import pickle
from collections import OrderedDict
from torch import amp

# =========================================================
# Phase-1 dump configuration
# =========================================================
CONFIG = dict(
    # Checkpoints: run_fast_dump_v2 builds models 1-2 with the C95 vocabulary size
    # and model 3 with the C99 vocabulary size.
    MODEL1_PATH="/kaggle/input/model-cafa6/best_model_wide_0.63.pth",
    MODEL2_PATH="/kaggle/input/model-cafa6/model_v3_hmc_warmup.pth",
    MODEL3_PATH="/kaggle/input/model-cafa6/best_model_c99_wide.pth",

    EMBED_DIR="/kaggle/input/cafa6-embeds",
    VOCAB_C95="/kaggle/input/c95-cafa6/vocab_C95_remove.csv",
    VOCAB_C99="/kaggle/input/c99-cafa6/vocab_C99_remove.csv",
    IA_FILE="/kaggle/input/cafa-6-protein-function-prediction/IA.tsv",

    # Blend weights for models 1/2/3 and the high-IA override gate
    BASE_W1=0.6, BASE_W2=0.1, BASE_W3=0.3,
    IA_HIGH_THRESHOLD=2.0,
    C99_CONFIDENCE_MIN=0.12,
    C99_DELTA_MIN=0.02,

    # Keep every score above 1% so the final filtering can be done offline later
    DUMP_THRESHOLD=0.01,

    device='cuda', batch_size=128, num_workers=4,
)

class WideProteinMLP(nn.Module):
    """Feed-forward classifier: LayerNorm -> [Linear -> GELU -> Dropout] per hidden layer -> Linear.

    Args:
        input_dim: size of the input feature vector.
        num_classes: number of output logits.
        hidden_dims: widths of the hidden layers, in order. An immutable tuple default
            avoids the shared-mutable-default pitfall; any iterable of ints is accepted.
        dropout: dropout probability applied after each hidden activation.
    """

    def __init__(self, input_dim, num_classes, hidden_dims=(2048, 4096), dropout=0.3):
        super().__init__()
        layers = [nn.LayerNorm(input_dim)]
        prev = input_dim
        for h in hidden_dims:
            layers += [nn.Linear(prev, h), nn.GELU(), nn.Dropout(dropout)]
            prev = h
        layers.append(nn.Linear(prev, num_classes))
        # Kept as a single Sequential named `net` so checkpoint keys stay "net.<i>.*".
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw logits (no final activation) with last dimension num_classes."""
        return self.net(x)

class TestDataset(Dataset):
    """Test proteins: IDs from test_ids.txt paired by row with test_embeds.npy.

    Item i is (row i of the embedding matrix as a float32 tensor, line i of the ID file
    with surrounding whitespace stripped).
    """

    def __init__(self, embed_dir):
        ids_path = os.path.join(embed_dir, "test_ids.txt")
        with open(ids_path) as handle:
            self.ids = list(map(str.strip, handle))
        # Memory-mapped so the whole matrix is not loaded into RAM up front.
        embeds_path = os.path.join(embed_dir, "test_embeds.npy")
        self.embed_matrix = np.load(embeds_path, mmap_mode="r")

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        # Copy the row out of the read-only memmap before handing it to torch.
        row = self.embed_matrix[idx].copy()
        features = torch.from_numpy(row).float()
        return features, self.ids[idx]

def load_and_clean_model(path, num_classes, device):
    """Build a WideProteinMLP (input_dim 1280), load weights from *path*, switch to eval mode.

    The checkpoint may be a raw state_dict or a dict holding one under 'model_state'.
    A leading 'module.' prefix on parameter names is stripped before loading.

    Returns:
        The loaded model, or None when loading fails. The error is printed together with
        *path*, so a failing checkpoint can be identified when several are loaded in a row.
    """
    m = WideProteinMLP(1280, num_classes).to(device)
    prefix = 'module.'
    try:
        # torch.load unpickles the file: only load checkpoints from trusted sources.
        ckpt = torch.load(path, map_location='cpu')
        sd = ckpt['model_state'] if isinstance(ckpt, dict) and 'model_state' in ckpt else ckpt
        new_sd = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in sd.items()
        )
        m.load_state_dict(new_sd)
    except Exception as e:
        print(f"‚ùå Error loading {path}: {e}")
        return None
    m.eval()
    return m

def run_fast_dump_v2():
    """Phase 1: blend three models on the test embeddings and dump raw top-K scores.

    All paths and weights come from the module-level CONFIG. For every test protein, up to
    DUMP_K lines "<protein>\\t<C99 term>\\t<score:.4f>" are written to raw_predictions.tsv
    in descending score order. Only scores >= CONFIG['DUMP_THRESHOLD'] are written.

    Blending per C99 term:
      * p1, p2 (C95 models) are scattered into C99 columns; C99 terms with no C95
        counterpart get 0 from those models.
      * base     = BASE_W1 * p1 + BASE_W2 * p2 + BASE_W3 * p3
      * override = 0.2 * p1 + 0.8 * p3, used where IA >= IA_HIGH_THRESHOLD and
        p3 > C99_CONFIDENCE_MIN and p3 > p1 + C99_DELTA_MIN.

    Raises:
        RuntimeError: if any of the three checkpoints cannot be loaded.
    """
    print("üöÄ PHASE 1: GPU-ACCELERATED RAW DUMP...")
    device = CONFIG['device']
    # Autocast is only enabled when running on CUDA.
    use_amp = torch.device(device).type == 'cuda'

    # 1. Vocabularies and the C95 -> C99 column mapping
    df_c95 = pd.read_csv(CONFIG['VOCAB_C95'])
    df_c99 = pd.read_csv(CONFIG['VOCAB_C99'])
    c95_terms = df_c95['term'].tolist()
    c99_terms = np.array(df_c99['term'].tolist())
    num_c99 = len(c99_terms)

    c99_term_to_idx = {t: i for i, t in enumerate(c99_terms)}
    c95_cols, c99_cols = [], []
    for i, t in enumerate(c95_terms):
        j = c99_term_to_idx.get(t)
        if j is not None:
            c95_cols.append(i)
            c99_cols.append(j)
    valid_c95_indices = torch.tensor(c95_cols, device=device, dtype=torch.long)
    c95_to_c99_indices = torch.tensor(c99_cols, device=device, dtype=torch.long)

    # IA per C99 term (0 for terms missing from the IA file)
    try:
        ia_df = pd.read_csv(CONFIG['IA_FILE'], sep='\t', names=['term', 'ia'], header=None)
        ia_map = dict(zip(ia_df.term, ia_df.ia))
        ia_vector = np.array([ia_map.get(t, 0.0) for t in c99_terms], dtype=np.float32)
        ia_tensor = torch.tensor(ia_vector, device=device)
    except Exception as err:
        # With all-zero IA the high-IA override below never fires; make that visible.
        print(f"WARNING: could not load IA file ({err}); IA gating disabled.")
        ia_tensor = torch.zeros(num_c99, device=device)

    # The IA gate depends only on the term, so compute it once.
    high_ia = ia_tensor >= CONFIG['IA_HIGH_THRESHOLD']

    # 2. Models
    m1 = load_and_clean_model(CONFIG['MODEL1_PATH'], len(c95_terms), device)
    m2 = load_and_clean_model(CONFIG['MODEL2_PATH'], len(c95_terms), device)
    m3 = load_and_clean_model(CONFIG['MODEL3_PATH'], num_c99, device)
    for key, model in (('MODEL1_PATH', m1), ('MODEL2_PATH', m2), ('MODEL3_PATH', m3)):
        if model is None:
            # load_and_clean_model already printed the underlying error.
            raise RuntimeError(f"Could not load {key}: {CONFIG[key]}")
    if torch.cuda.device_count() > 1:
        m1 = nn.DataParallel(m1)
        m2 = nn.DataParallel(m2)
        m3 = nn.DataParallel(m3)

    # 3. Inference loop
    ds = TestDataset(CONFIG['EMBED_DIR'])
    dl = DataLoader(ds, batch_size=CONFIG['batch_size'], shuffle=False,
                    num_workers=CONFIG.get('num_workers', 4))

    DUMP_K = 500
    DUMP_THR = CONFIG['DUMP_THRESHOLD']
    top_k = min(DUMP_K, num_c99)  # torch.topk fails when k exceeds the number of columns

    print(">>> Streaming & Dumping (GPU TopK Mode)...")
    with open("raw_predictions.tsv", "w") as f_out, torch.no_grad():
        for features, prot_ids in tqdm(dl):
            features = features.to(device)
            with amp.autocast(device_type='cuda', enabled=use_amp):
                p1_raw = torch.sigmoid(m1(features))
                p2_raw = torch.sigmoid(m2(features))
                p3 = torch.sigmoid(m3(features))

            # Scatter C95 predictions into C99 columns
            p1_mapped = torch.zeros_like(p3)
            p2_mapped = torch.zeros_like(p3)
            p1_mapped.index_add_(1, c95_to_c99_indices, p1_raw[:, valid_c95_indices])
            p2_mapped.index_add_(1, c95_to_c99_indices, p2_raw[:, valid_c95_indices])

            # Gating: trust model 3 more on high-IA terms where it is confident and ahead of model 1
            override_mask = (
                high_ia
                & (p3 > CONFIG['C99_CONFIDENCE_MIN'])
                & (p3 > (p1_mapped + CONFIG['C99_DELTA_MIN']))
            )
            base_prob = ((CONFIG['BASE_W1'] * p1_mapped)
                         + (CONFIG['BASE_W2'] * p2_mapped)
                         + (CONFIG['BASE_W3'] * p3))
            override_prob = (0.2 * p1_mapped) + (0.8 * p3)
            final_probs = torch.where(override_mask, override_prob, base_prob)

            # Push sub-threshold scores to -1 so top-k never prefers them over real hits
            final_probs.masked_fill_(final_probs < DUMP_THR, -1.0)
            topk_vals, topk_inds = torch.topk(final_probs, k=top_k, dim=1)
            vals_np = topk_vals.float().cpu().numpy()
            inds_np = topk_inds.cpu().numpy()

            batch_lines = []
            for pid, row_vals, row_inds in zip(prot_ids, vals_np, inds_np):
                for score, idx in zip(row_vals, row_inds):
                    # topk is sorted descending, so everything after this is below threshold too.
                    if score < DUMP_THR:
                        break
                    batch_lines.append(f"{pid}\t{c99_terms[idx]}\t{score:.4f}\n")

            if batch_lines:
                f_out.write("".join(batch_lines))

    print("‚úÖ PHASE 1 DONE! Saved 'raw_predictions.tsv'.")


if __name__ == "__main__":
    run_fast_dump_v2()

In [None]:
import pandas as pd
import numpy as np
import pickle
from tqdm import tqdm
import os

# =========================================================
# Post-processing configuration (Phase 2)
# =========================================================
CONFIG_POST = dict(
    RAW_FILE="/kaggle/input/sample-submission/raw_predictions_500.tsv",
    OUTPUT_FILE="submission.tsv",
    METADATA_PKL="hierarchy_metadata.pkl",

    # --- Final filtering rules ---
    FINAL_THRESHOLD=0.20,
    FINAL_CAP=180,   # keep at most this many terms per protein
    MIN_LABELS=1,    # safety net: always emit at least this many terms
    ROOT_TERMS=["GO:0003674", "GO:0005575", "GO:0008150"],
)

def run_offline_repair():
    """Phase 2: enforce hierarchy consistency on the raw dump, then threshold and cap.

    Reads CONFIG_POST['RAW_FILE'] (tab-separated protein, GO term, score; no header) in
    chunks. Each run of consecutive rows with the same protein ID is treated as one
    protein, so the file must be grouped by protein (run_fast_dump_v2 writes its output
    that way). Rows whose term is not in the vocabulary are ignored. For each protein:
      1. every ancestor in the metadata's child_to_parent map is raised to at least the
         score of each of its descendants (max-propagation);
      2. terms scoring >= FINAL_THRESHOLD are kept, falling back to the MIN_LABELS best
         terms when too few pass;
      3. the result is cut to the FINAL_CAP highest scores.
    Lines "<protein>\\t<term>\\t<score:.3f>" are written to CONFIG_POST['OUTPUT_FILE'].

    NOTE(review): CONFIG_POST['ROOT_TERMS'] is not used here; if the vocabulary contains
    the GO roots, they are emitted like any other term.
    """
    print("üöÄ PHASE 2: OFFLINE REPAIR & FILTER...")

    # 1. Metadata
    print("   Loading Metadata...")
    # pickle.load can execute arbitrary code: only load metadata produced by this pipeline.
    with open(CONFIG_POST['METADATA_PKL'], 'rb') as f:
        meta = pickle.load(f)
    if 'child_to_parent' in meta:
        parent_map = meta['child_to_parent']
    else:
        parent_map = meta.get('child_to_parent_idx', {})

    # parent_map holds indices into the vocabulary the metadata was built from, so use the
    # mapping saved alongside it; rebuild from the C99 vocab file only if it is missing.
    term_to_idx = meta.get('term_to_idx')
    if term_to_idx is None:
        vocab_df = pd.read_csv("/kaggle/input/c99-cafa6/vocab_C99_remove.csv")
        term_to_idx = {t: i for i, t in enumerate(vocab_df['term'])}
    idx_to_term = {i: t for t, i in term_to_idx.items()}

    threshold = CONFIG_POST['FINAL_THRESHOLD']
    cap = CONFIG_POST['FINAL_CAP']
    min_labels = CONFIG_POST['MIN_LABELS']

    def repair_protein(pid, data_list):
        """Return the output lines for one protein (empty list when it has no scores)."""
        if not data_list:
            return []
        scores = dict(data_list)  # a repeated term keeps its last score

        # Max-propagation over the (transitive) ancestor map
        queue = list(scores)
        queued = set(queue)
        head = 0
        while head < len(queue):
            c_idx = queue[head]
            head += 1
            c_score = scores[c_idx]
            for p_idx in parent_map.get(c_idx, []):
                # Ancestors absent from the dump start at 0.0 (the dump is already thresholded).
                p_prev = scores.get(p_idx, 0.0)
                if c_score > p_prev + 1e-6:
                    scores[p_idx] = c_score
                    if p_idx not in queued:
                        queue.append(p_idx)
                        queued.add(p_idx)

        # Threshold, minimum-label fallback, cap
        ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
        kept = [(i, s) for i, s in ranked if s >= threshold]
        if len(kept) < min_labels:
            kept = ranked[:min_labels]
        kept = kept[:cap]
        return [f"{pid}\t{idx_to_term[i]}\t{s:.3f}\n" for i, s in kept]

    # 2. Stream the raw file, grouping consecutive rows by protein
    print("   Reading & Processing...")
    reader = pd.read_csv(CONFIG_POST['RAW_FILE'], sep='\t', names=['PID', 'Term', 'Score'],
                         chunksize=1000000)

    current_pid = None
    current_data = []
    count = 0
    with open(CONFIG_POST['OUTPUT_FILE'], "w") as f_out:
        for chunk in tqdm(reader, desc="Processing Chunks"):
            # Vectorised term -> index lookup; unknown terms map to NaN and are dropped.
            term_idx = chunk['Term'].map(term_to_idx)
            known = term_idx.notna()
            pids = chunk['PID'][known].tolist()
            idxs = term_idx[known].astype(int).tolist()
            vals = chunk['Score'][known].tolist()

            for pid, t_idx, score in zip(pids, idxs, vals):
                if pid != current_pid:
                    if current_pid is not None:
                        lines = repair_protein(current_pid, current_data)
                        f_out.write("".join(lines))
                        count += len(lines)
                    current_pid = pid
                    current_data = []
                current_data.append((t_idx, score))

        # Flush the last protein
        if current_pid is not None:
            lines = repair_protein(current_pid, current_data)
            f_out.write("".join(lines))
            count += len(lines)

    print(f"\n‚úÖ REPAIR DONE! Rows: {count:,}")


if __name__ == "__main__":
    run_offline_repair()

🚀 PHASE 2: OFFLINE REPAIR & FILTER...
   Loading Metadata...
   Reading & Processing...


Processing Chunks: 113it [06:47,  3.60s/it]


✅ REPAIR DONE! Rows: 39,543,118





In [8]:
# import pandas as pd
# from tqdm import tqdm
# import os
# import numpy as np

# # =========================================================
# # C·∫§U H√åNH KI·ªÇM TRA
# # =========================================================
# SUB_FILE = "submission.tsv"
# TEST_IDS_FILE = "/kaggle/input/cafa6-embeds/test_ids.txt" # Thay b·∫±ng ƒë∆∞·ªùng d·∫´n file ID th·ª±c t·∫ø
# CHUNK_SIZE = 5_000_000 

# # =========================================================
# # H√ÄM KI·ªÇM TRA
# # =========================================================

# def check_protein_coverage(submission_path, test_ids_path):
#     print(f"üî¨ B·∫Øt ƒë·∫ßu ki·ªÉm tra file: {submission_path}")
    
#     if not os.path.exists(submission_path):
#         print(f"‚ùå L·ªñI: Kh√¥ng t√¨m th·∫•y file submission t·∫°i: {submission_path}")
#         return False
        
#     # 1. Load danh s√°ch ID B·∫ÆT BU·ªòC (224,309 ID)
#     try:
#         with open(test_ids_path, 'r') as f:
#             required_pids = set(line.strip() for line in f)
        
#         TOTAL_REQUIRED = len(required_pids)
#         print(f"‚úÖ ƒê√£ load {TOTAL_REQUIRED:,} Protein ID b·∫Øt bu·ªôc.")
        
#     except Exception as e:
#         print(f"‚ùå L·ªñI: Kh√¥ng th·ªÉ ƒë·ªçc danh s√°ch ID test t·ª´ {test_ids_path}. L·ªói: {e}")
#         return False

#     # 2. ƒê·ªçc file submission theo chunks v√† thu th·∫≠p ID
#     found_pids = set()
#     total_submission_rows = 0
    
#     try:
#         # S·ª≠ d·ª•ng engine='python' v√¨ file TSV c√≥ th·ªÉ c√≥ l·ªói ƒë·ªãnh d·∫°ng nh·∫π
#         reader = pd.read_csv(submission_path, sep='\t', header=None, 
#                               usecols=[0], names=['Protein'], 
#                               chunksize=CHUNK_SIZE, engine='python')
        
#         # Qu√©t qua t·ª´ng chunk
#         for chunk in tqdm(reader, desc="Scanning PIDs in Submission"):
#             total_submission_rows += len(chunk)
#             # Th√™m c√°c ID duy nh·∫•t t·ª´ chunk v√†o t·∫≠p h·ª£p
#             found_pids.update(chunk['Protein'].astype(str).unique())
            
#     except Exception as e:
#         print(f"‚ùå L·ªñI: L·ªói khi ƒë·ªçc file submission theo chunk. L·ªói: {e}")
#         print("   => Vui l√≤ng ki·ªÉm tra l·∫°i c·∫•u tr√∫c file submission (TAB separator).")
#         return False

#     # 3. Ph√¢n t√≠ch k·∫øt qu·∫£
#     TOTAL_FOUND = len(found_pids)
    
#     # Ki·ªÉm tra s·ªë l∆∞·ª£ng Protein ID b·ªã thi·∫øu
#     missing_pids = required_pids - found_pids
#     TOTAL_MISSING = len(missing_pids)
    
#     print("\n" + "="*50)
#     print(f"üìä K·∫æT QU·∫¢ KI·ªÇM TRA PH·ª¶ S√ìNG PROTEIN")
#     print("="*50)
#     print(f"1. T·ªïng ID c·∫ßn thi·∫øt: {TOTAL_REQUIRED:,}")
#     print(f"2. T·ªïng ID t√¨m th·∫•y: {TOTAL_FOUND:,}")
#     print(f"3. T·ªïng d√≤ng Submission: {total_submission_rows:,}")
    
#     if TOTAL_MISSING == 0:
#         print("‚úÖ TH√ÄNH C√îNG: File submission ƒë√£ bao g·ªìm ƒê·∫¶Y ƒê·ª¶ 224,309 Protein ID.")
#     else:
#         print(f"‚ùå THI·∫æU {TOTAL_MISSING:,} Protein ID!")
#         # In ra 5 ID b·ªã thi·∫øu ƒë·∫ßu ti√™n l√†m v√≠ d·ª•
#         print(f"   => 5 ID b·ªã thi·∫øu l√†m v√≠ d·ª•: {list(missing_pids)[:5]}")
        
#     print("="*50)
#     return TOTAL_MISSING == 0

# if __name__ == "__main__":
#     # CH·∫†Y H√ÄM KI·ªÇM TRA (B·∫°n c·∫ßn thay ƒë·ªïi ƒë∆∞·ªùng d·∫´n)
#     # V√≠ d·ª• m·∫´u:
#     # check_protein_coverage("submission.tsv", "/kaggle/input/cafa6-embeds/test_ids.txt")
    
#     # Gi·∫£ s·ª≠ file ID test l√† 224,309 ID
#     # T·ª∞ CH·∫†Y H√ÄM N√ÄY SAU KHI THAY ƒê·ªîI ƒê∆Ø·ªúNG D·∫™N
    
    
#     check_protein_coverage(SUB_FILE, TEST_IDS_FILE)
    
    