In [4]:
import os
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from torch.amp import autocast
import pickle
import gc

# =========================================================
# CONFIG
# =========================================================
# Test-split inputs: one protein ID per line, and the embedding matrix that is
# indexed by the same row position.
_TEST_INPUTS = {
    "TEST_IDS": "/root/CAFA6data/cafa6-embeds/test_ids.txt",
    "TEST_EMBEDS": "/root/CAFA6data/cafa6-embeds/test_embeds.npy",
}

# Term vocabularies of the two models and the per-term IA table.
_TERM_FILES = {
    "VOCAB_C95": "/root/CAFA6data/c95/vocab_C95_remove.csv",
    "VOCAB_C99": "/root/CAFA6data/c99/vocab_C99_remove.csv",
    "IA_FILE": "/root/CAFA6data/IA.tsv",
}

# Trained checkpoints, one per vocabulary.
_CHECKPOINTS = {
    "MODEL_C95": "/root/cafa6/c95/final_cafa6_model_c95.pth",
    "MODEL_C99": "/root/cafa6/c99/final_cafa6_model_c99.pth",
}

# Post-processing: keep at most ``top_k`` terms per protein, drop scores below
# ``threshold``, and write the rows to ``output_file``.
_POSTPROCESSING = {
    "top_k": 500,
    "threshold": 0.001,
    "output_file": "submission_c95_c99_final.tsv",
}

# Execution settings for the inference DataLoader and the target device.
_RUNTIME = {
    "device": "cuda",
    "batch_size": 1024,
    "num_workers": 8,
}

# Pickled protein -> taxon-index mapping.
_TAXONOMY = {
    "TAXON_PKL": "/root/cafa6/preprocessing/taxon_mapping_K_Species.pkl",
}

# Flat settings dictionary consumed by the rest of the script.
CONFIG = {
    **_TEST_INPUTS,
    **_TERM_FILES,
    **_CHECKPOINTS,
    **_POSTPROCESSING,
    **_RUNTIME,
    **_TAXONOMY,
}

# =========================================================
# MODEL
# =========================================================
class WideProteinMLP_WithTaxon(nn.Module):
    """MLP classifier over a sequence feature vector plus a learned taxon embedding.

    The sequence features and the taxon embedding are each layer-normalised,
    concatenated along the feature axis, and passed through a stack of
    ``Linear -> GELU -> Dropout`` blocks followed by a final ``Linear`` that
    produces one raw score per class (no activation is applied here).

    Args:
        input_dim: Width of the sequence feature vector.
        num_classes: Number of output scores.
        num_taxa: Number of rows in the taxon embedding table.
        taxon_dim: Width of the taxon embedding.
        hidden_dims: Widths of the hidden layers, in order. Any iterable of
            ints is accepted; the default is an immutable tuple so it cannot
            be shared and mutated across calls.
        dropout: Dropout probability applied after every hidden activation.
    """

    def __init__(self, input_dim, num_classes, num_taxa, taxon_dim=64,
                 hidden_dims=(4096, 4096), dropout=0.4):
        super().__init__()
        self.seq_norm = nn.LayerNorm(input_dim)
        self.taxon_embedding = nn.Embedding(num_taxa, taxon_dim)
        self.taxon_norm = nn.LayerNorm(taxon_dim)

        # The position of each module inside ``net`` determines its
        # state_dict key (net.0, net.3, ...), so the layer order must not
        # change or saved checkpoints will no longer load.
        layers = []
        prev = input_dim + taxon_dim
        for h in hidden_dims:
            layers += [nn.Linear(prev, h), nn.GELU(), nn.Dropout(dropout)]
            prev = h
        layers.append(nn.Linear(prev, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, seq, tax):
        """Return raw class scores for a batch.

        Args:
            seq: Sequence features; expected as ``[batch, input_dim]`` since
                they are concatenated with the taxon embedding along dim 1.
            tax: Integer taxon indices, one per batch row.

        Returns:
            Tensor of raw scores with ``num_classes`` entries per row.
        """
        x = torch.cat([self.seq_norm(seq), self.taxon_norm(self.taxon_embedding(tax))], dim=1)
        return self.net(x)

# =========================================================
# DATASET
# =========================================================
class TestDataset(Dataset):
    """Test-set proteins served as ``(features, taxon_index, protein_id)``.

    Row ``i`` of the embedding matrix belongs to the ``i``-th ID in the ID
    file, so the two inputs must have the same number of entries.

    Args:
        ids_path: Text file with one protein ID per line.
        embeds_path: ``.npy`` file with one embedding row per protein; it is
            memory-mapped read-only rather than loaded into RAM.
        taxon_pkl: Pickle holding ``prot_to_taxon_idx`` (protein ID -> taxon
            index) and ``num_taxa_classes``.

    Raises:
        ValueError: If the number of IDs differs from the number of
            embedding rows.
    """

    def __init__(self, ids_path, embeds_path, taxon_pkl):
        # Blank lines (typically a trailing newline) are not protein IDs;
        # keeping them would make the ID list longer than the embedding
        # matrix and shift the ID/row alignment.
        with open(ids_path) as f:
            stripped = (line.strip() for line in f)
            self.ids = [pid for pid in stripped if pid]

        self.embeds = np.load(embeds_path, mmap_mode="r")

        # Fail fast on misaligned inputs instead of crashing at the last
        # batch or silently pairing IDs with the wrong embedding rows.
        n_rows = self.embeds.shape[0]
        if len(self.ids) != n_rows:
            raise ValueError(
                f"ID/embedding count mismatch: {len(self.ids)} ids in "
                f"{ids_path} vs {n_rows} rows in {embeds_path}"
            )

        # NOTE: pickle.load can execute arbitrary code; only point this at a
        # mapping file produced by a trusted preprocessing step.
        with open(taxon_pkl, "rb") as f:
            tax = pickle.load(f)
        self.tax_map = tax["prot_to_taxon_idx"]
        self.num_taxa = tax["num_taxa_classes"]
        # Proteins missing from the mapping fall back to the last taxon index.
        self.default_tax = self.num_taxa - 1

    def __len__(self):
        """Return the number of proteins."""
        return len(self.ids)

    def __getitem__(self, i):
        """Return ``(float32 features, int64 taxon index, protein ID)`` for row ``i``."""
        pid = self.ids[i]
        # torch.tensor copies the row out of the read-only memory map.
        feat = torch.tensor(self.embeds[i], dtype=torch.float32)

        tax = self.tax_map.get(pid, self.default_tax)
        return feat, torch.tensor(tax, dtype=torch.long), pid

# =========================================================
# LOAD MODEL
# =========================================================
def load_model(path, n_cls, num_taxa, dropout, device, input_dim=1280):
    """Build a ``WideProteinMLP_WithTaxon`` and load its weights for inference.

    Args:
        path: Checkpoint file containing the model's state dict.
        n_cls: Number of output classes the checkpoint was trained with.
        num_taxa: Size of the taxon embedding table.
        dropout: Dropout probability used to construct the model (inactive
            after ``eval()``, but kept so the architecture matches training).
        device: Device the model and the loaded weights are placed on.
        input_dim: Width of the sequence feature vector. Defaults to 1280,
            the value previously hard-coded here.

    Returns:
        The model in evaluation mode on ``device``.
    """
    print(f"Loading model: {path}")
    model = WideProteinMLP_WithTaxon(
        input_dim, n_cls, num_taxa, dropout=dropout
    ).to(device)

    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state dict needs, and avoids executing arbitrary code
    # embedded in a checkpoint file.
    state_dict = torch.load(path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model

# =========================================================
# MAIN
# =========================================================
def _blend_by_ia(p95, p99, ia):
    """Blend C95 and C99 probabilities with per-term weights chosen by IA.

    Args:
        p95: C95 probabilities already mapped into the C99 term space
            (same shape as ``p99``).
        p99: C99 probabilities.
        ia: Per-term IA values, broadcastable against ``p99`` (the caller
            passes a ``[1, C]`` tensor).

    Returns:
        Tensor shaped like ``p99`` holding the blended scores.
    """
    # Zone 1: IA < 2 -> rely almost entirely on C95.
    mask_z1 = ia < 2.0
    prob_z1 = 0.95 * p95 + 0.05 * p99

    # Zone 2: 2 <= IA < 4 -> lean towards C99 only when it is clearly higher.
    mask_z2 = (ia >= 2.0) & (ia < 4.0)
    better_c99_z2 = p99 > (p95 + 0.08)
    prob_z2 = torch.where(
        better_c99_z2,
        0.45 * p95 + 0.55 * p99,
        0.60 * p95 + 0.40 * p99
    )

    # Zone 3: IA >= 4 -> C99 dominates in both cases.
    mask_z3 = ia >= 4.0
    # Base weighting, used when C99 does not clearly exceed C95.
    prob_z3_base = 0.20 * p95 + 0.80 * p99
    # Stronger C99 weighting when C99 exceeds C95 by the margin below.
    prob_z3_strong = 0.85 * p99 + 0.15 * p95
    better_c99_z3 = p99 > (p95 + 0.05)
    prob_z3 = torch.where(
        better_c99_z3,
        prob_z3_strong,
        prob_z3_base
    )

    # The three masks are mutually exclusive, so the sum selects one zone
    # per term.
    return (
        prob_z1 * mask_z1 +
        prob_z2 * mask_z2 +
        prob_z3 * mask_z3
    )


def main():
    """Run the C95 + C99 ensemble on the test set and write the submission file.

    Loads both vocabularies, the IA table and both checkpoints, predicts
    every test protein in batches, blends the two models' probabilities per
    term according to IA, keeps the top-k terms per protein whose score is at
    least the configured threshold, and appends headerless tab-separated
    ``id, term, score`` rows to ``CONFIG["output_file"]`` (scores formatted
    with three decimals). Any existing output file is removed first.
    """
    torch.backends.cudnn.benchmark = True
    device = CONFIG["device"]

    print(" Loading Taxonomy Mapping...")
    # NOTE: pickle.load can execute arbitrary code; the mapping file must
    # come from a trusted preprocessing step.
    with open(CONFIG["TAXON_PKL"], "rb") as f:
        tax_data = pickle.load(f)

    NUM_TAXA = tax_data["num_taxa_classes"]
    print(f" NUM_TAXA loaded from PKL = {NUM_TAXA}")

    taxon_indices = tax_data["prot_to_taxon_idx"].values()
    print("Taxon stats:",
          "min =", min(taxon_indices),
          "max =", max(taxon_indices))

    print(" Loading Vocab & Maps...")
    df95 = pd.read_csv(CONFIG["VOCAB_C95"])
    df99 = pd.read_csv(CONFIG["VOCAB_C99"])

    terms95 = df95.term.tolist()
    # Kept as an array so term strings can be looked up by index in bulk.
    terms99 = np.array(df99.term.tolist())

    # Column indices of the terms shared by both vocabularies:
    # src95[j] in the C95 output corresponds to dst99[j] in the C99 output.
    # Built as two explicit long tensors so an empty overlap yields empty
    # index tensors instead of failing to unpack.
    t99_idx = {t: i for i, t in enumerate(terms99)}
    shared = [(i, t99_idx[t]) for i, t in enumerate(terms95) if t in t99_idx]
    src95 = torch.tensor([s for s, _ in shared], dtype=torch.long, device=device)
    dst99 = torch.tensor([d for _, d in shared], dtype=torch.long, device=device)

    # Per-term IA in C99 order; terms missing from the IA file get 0.0.
    ia_df = pd.read_csv(CONFIG["IA_FILE"], sep="\t", header=None, names=["term", "ia"])
    ia_map = dict(zip(ia_df.term, ia_df.ia))
    ia_vec = torch.tensor([ia_map.get(t, 0.0) for t in terms99], device=device)
    ia = ia_vec.unsqueeze(0)  # [1, C], broadcast over the batch

    # --- Load Models ---
    m95 = load_model(CONFIG["MODEL_C95"], len(terms95), NUM_TAXA, 0.4, device)
    m99 = load_model(CONFIG["MODEL_C99"], len(terms99), NUM_TAXA, 0.25, device)

    # --- DataLoader ---
    dataset = TestDataset(CONFIG["TEST_IDS"], CONFIG["TEST_EMBEDS"], CONFIG["TAXON_PKL"])
    num_workers = CONFIG["num_workers"]
    loader_kwargs = {
        "batch_size": CONFIG["batch_size"],
        "shuffle": False,
        "num_workers": num_workers,
        "pin_memory": True,
    }
    # DataLoader rejects prefetch_factor when loading in the main process.
    if num_workers > 0:
        loader_kwargs["prefetch_factor"] = 2
    loader = DataLoader(dataset, **loader_kwargs)

    # Remove the old file if present, because batches are appended below.
    if os.path.exists(CONFIG["output_file"]):
        os.remove(CONFIG["output_file"])

    # topk raises if k exceeds the number of classes, so clamp it once.
    top_k = min(CONFIG["top_k"], len(terms99))

    print(f" Predicting {len(dataset)} proteins with Batch Size {CONFIG['batch_size']}...")

    # --- Inference Loop ---
    with torch.inference_mode():
        for feats, tax, pids in tqdm(loader):
            feats = feats.to(device, non_blocking=True)
            tax = tax.to(device, non_blocking=True)

            with autocast("cuda"):
                p99 = torch.sigmoid(m99(feats, tax))
                p95_raw = torch.sigmoid(m95(feats, tax))

                # Map C95 predictions into the C99 term space.
                # NOTE(review): terms that exist only in the C99 vocabulary
                # keep p95 = 0, so the blend scales their p99 down by the
                # zone weight - confirm this is intended.
                p95 = torch.zeros_like(p99)
                p95[:, dst99] = p95_raw[:, src95]

                p_final = _blend_by_ia(p95, p99, ia)

            # --- Post-processing (vectorized) ---
            # 1. Zero out scores below the threshold so they are filtered
            #    by the ``> 0`` mask after top-k.
            p_final[p_final < CONFIG["threshold"]] = 0.0

            # 2. Top-K per protein.
            top_vals, top_inds = torch.topk(p_final, top_k, dim=1)

            # 3. Move to CPU.
            vals_np = top_vals.cpu().numpy()  # [B, K]
            inds_np = top_inds.cpu().numpy()  # [B, K]

            # --- Output writing ---
            vals_flat = vals_np.flatten()
            inds_flat = inds_np.flatten()

            # Repeat each protein ID K times to line up with the flattened
            # [B, K] score/index arrays.
            pids_list = np.array(pids)
            pids_flat = np.repeat(pids_list, top_k)

            valid_mask = vals_flat > 0

            if not np.any(valid_mask):
                continue

            # Keep only the valid (non-zero) predictions.
            final_pids = pids_flat[valid_mask]
            final_inds = inds_flat[valid_mask]
            final_vals = vals_flat[valid_mask]

            # Map index -> term string (vectorized lookup).
            final_terms = terms99[final_inds]

            df_batch = pd.DataFrame({
                'id': final_pids,
                'term': final_terms,
                'score': final_vals
            })

            df_batch.to_csv(
                CONFIG["output_file"],
                sep='\t',
                header=False,
                index=False,
                mode='a',
                float_format='%.3f'
            )

    print(" SUBMISSION READY:", CONFIG["output_file"])

# Run the inference pipeline only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

 Loading Taxonomy Mapping...
 NUM_TAXA loaded from PKL = 135
Taxon stats: min = 0 max = 134
 Loading Vocab & Maps...
Loading model: /root/cafa6/c95/final_cafa6_model_c95.pth


Loading model: /root/cafa6/c99/final_cafa6_model_c99.pth
 Predicting 224309 proteins with Batch Size 1024...


100%|██████████| 220/220 [04:12<00:00,  1.15s/it]

 SUBMISSION READY: submission_c95_c99_final.tsv



