In [None]:
import os
import torch
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import gc
import torch.nn as nn
from collections import OrderedDict
from torch import amp

class WideProteinMLP(nn.Module):
    """Feed-forward classifier: LayerNorm, a stack of wide hidden layers, linear head.

    ``self.net`` is an ``nn.Sequential`` laid out as
    ``LayerNorm(input_dim)``, then ``Linear -> GELU -> Dropout`` for every
    width in ``hidden_dims``, then a final ``Linear`` producing
    ``num_classes`` raw logits (no activation is applied to the output).

    Args:
        input_dim: Size of the last dimension of the input features.
        num_classes: Number of output logits.
        hidden_dims: Iterable with the width of each hidden layer, in order.
        dropout: Dropout probability applied after every hidden activation.
    """

    def __init__(self, input_dim, num_classes, hidden_dims=(2048, 4096), dropout=0.3):
        super().__init__()
        # Input norm: helps stabilise the incoming embedding features.
        layers = [nn.LayerNorm(input_dim)]

        prev = input_dim
        for h in hidden_dims:
            # Linear is built first so parameter creation order is unchanged.
            layers.extend((nn.Linear(prev, h), nn.GELU(), nn.Dropout(dropout)))
            prev = h

        layers.append(nn.Linear(prev, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw logits computed by ``self.net`` for ``x``."""
        return self.net(x)

# ============================================================================
# 1. INFERENCE CONFIGURATION
# ============================================================================
# Central settings for the inference run; every later step reads from here.
CONFIG = dict(
    # --- Model architecture (must match the values used for training) ---
    input_dim=1280,               # ESM2-t33 embeddings
    hidden_dims=[2048, 4096],     # Wide MLP layers
    dropout=0.3,                  # Dropout
    num_classes=6413,             # Number of labels in the C95_remove set

    # --- System ---
    device='cuda' if torch.cuda.is_available() else 'cpu',
    batch_size=64,
    num_workers=2,

    # --- Paths ---
    MODEL_PATH="/kaggle/input/model-cafa6/best_model_wide_0.63.pth",
    EMBED_DIR="/kaggle/input/cafa6-embeds",
    VOCAB_FILE="/kaggle/input/c95-cafa6/vocab_C95_remove.csv",

    # --- Post-processing ---
    min_score_threshold=0.2,
    submission_limit=150,
)

# Startup banner showing the selected device and score threshold.
print(f"üöÄ C·∫•u h√¨nh Inference: Device={CONFIG['device']} | Threshold={CONFIG['min_score_threshold']}")

# ============================================================================
# 2. DATASET DEFINITION FOR THE TEST SET
# ============================================================================
class TestDataset(Dataset):
    """Map-style dataset pairing stored embedding rows with their identifiers.

    Reads ``test_ids.txt`` (one identifier per line) and memory-maps
    ``test_embeds.npy`` from ``embed_dir``. Item ``i`` is row ``i`` of the
    matrix together with the ``i``-th identifier.

    Args:
        embed_dir: Directory containing ``test_ids.txt`` and ``test_embeds.npy``.

    Raises:
        ValueError: If the id file lists more identifiers than the embedding
            matrix has rows.
    """

    def __init__(self, embed_dir):
        with open(os.path.join(embed_dir, "test_ids.txt"), encoding="utf-8") as f:
            self.ids = [line.strip() for line in f]
        # Drop only *trailing* blank lines: they would become empty ids with
        # no matching embedding row. Interior lines are kept so the row
        # alignment with the matrix is never shifted.
        while self.ids and not self.ids[-1]:
            self.ids.pop()

        # Memory-map the matrix read-only instead of loading it into RAM.
        self.embed_matrix = np.load(os.path.join(embed_dir, "test_embeds.npy"), mmap_mode="r")

        # Fail fast here rather than with an IndexError deep inside a worker.
        n_rows = self.embed_matrix.shape[0]
        if len(self.ids) > n_rows:
            raise ValueError(
                f"test_ids.txt lists {len(self.ids)} ids but test_embeds.npy "
                f"has only {n_rows} rows"
            )

    def __len__(self):
        """Return the number of identifiers."""
        return len(self.ids)

    def __getitem__(self, idx):
        """Return ``(float32 feature tensor, identifier)`` for position ``idx``."""
        # Copy the row out of the read-only memory map so the tensor owns
        # writable memory.
        feat = torch.from_numpy(self.embed_matrix[idx].copy()).float()
        return feat, self.ids[idx]

# ============================================================================
# 3. INFERENCE LOOP
# ============================================================================
def _select_top_terms(row_probs, threshold, limit):
    """Return ``(class indices, scores)`` with score >= threshold, capped to the ``limit`` highest."""
    idxs = np.where(row_probs >= threshold)[0]
    scores = row_probs[idxs]
    if len(idxs) > limit:
        # argpartition picks the `limit` largest scores without a full sort;
        # the selected entries are not ordered by score.
        keep = np.argpartition(scores, -limit)[-limit:]
        idxs = idxs[keep]
        scores = scores[keep]
    return idxs, scores


def run_inference_gpu_stream():
    """Run the model over the test set and stream predictions to ``submission.tsv``.

    Reads all settings from the module-level ``CONFIG``: loads the vocabulary
    (``term`` column of ``VOCAB_FILE``), builds a ``TestDataset`` over
    ``EMBED_DIR``, restores the ``WideProteinMLP`` weights from ``MODEL_PATH``
    and writes one tab-separated line ``<id>\\t<term>\\t<score>`` for every
    class whose sigmoid score is at least ``min_score_threshold``, keeping at
    most ``submission_limit`` classes per id. Ids with no class above the
    threshold produce no lines.

    Returns:
        None. The predictions are written to ``submission.tsv`` in the
        current working directory and progress is printed to stdout.
    """
    device = CONFIG['device']
    use_cuda = torch.device(device).type == 'cuda'
    threshold = CONFIG['min_score_threshold']
    limit = CONFIG['submission_limit']

    # 1. Setup
    print(f"Loading vocab from {CONFIG['VOCAB_FILE']}...")
    vocab_terms = pd.read_csv(CONFIG['VOCAB_FILE'])['term'].values

    ds = TestDataset(CONFIG['EMBED_DIR'])
    # Worker count comes from CONFIG rather than a hard-coded literal;
    # pinned memory is only useful when batches are copied to a CUDA device.
    dl = DataLoader(ds, batch_size=CONFIG['batch_size'], shuffle=False,
                    num_workers=CONFIG['num_workers'], pin_memory=use_cuda)

    model = WideProteinMLP(CONFIG['input_dim'], CONFIG['num_classes'],
                           CONFIG['hidden_dims'], CONFIG['dropout'])

    # 2. Load weights
    print(f"Loading model from {CONFIG['MODEL_PATH']}...")
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source.
    ckpt = torch.load(CONFIG['MODEL_PATH'], map_location="cpu")
    sd = ckpt['model_state'] if isinstance(ckpt, dict) and 'model_state' in ckpt else ckpt
    # Strip the "module." prefix that nn.DataParallel adds to parameter
    # names; only a leading prefix is removed, never an interior match.
    prefix = "module."
    clean_sd = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in sd.items()
    }
    model.load_state_dict(clean_sd)

    # 3. Use every visible GPU
    model.to(device)
    if torch.cuda.device_count() > 1:
        print(f"üî• K√≠ch ho·∫°t {torch.cuda.device_count()} GPUs!")
        model = nn.DataParallel(model)

    model.eval()

    # 4. Inference loop
    out_file = "submission.tsv"
    print(f"Streaming predictions to {out_file}...")

    count = 0
    # The context manager closes the file even if inference raises midway.
    with open(out_file, "w") as f, torch.no_grad():
        for features, prot_ids in tqdm(dl):
            features = features.to(device)

            # Mixed precision only on CUDA; on CPU the context is a no-op.
            with amp.autocast('cuda', enabled=use_cuda):
                logits = model(features)
                probs = torch.sigmoid(logits)

            probs = probs.float().cpu().numpy()

            for pid, p in zip(prot_ids, probs):
                idxs, scores = _select_top_terms(p, threshold, limit)
                if idxs.size == 0:
                    continue

                # One buffered call per protein instead of one write per row.
                f.writelines(
                    f"{pid}\t{term}\t{sc:.3f}\n"
                    for term, sc in zip(vocab_terms[idxs], scores)
                )
                count += idxs.size

            # Manual cleanup of the per-batch tensors before the next batch.
            del features, logits, probs

    print(f"‚úÖ DONE! Created submission.tsv with {count:,} rows.")

# Entry point: run inference only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    run_inference_gpu_stream()

üöÄ C·∫•u h√¨nh Inference: Device=cuda | Threshold=0.2
Loading vocab from /kaggle/input/c95-cafa6/vocab_C95_remove.csv...
Loading model from /kaggle/input/model-cafa6/best_model_wide_0.63.pth...
üî• K√≠ch ho·∫°t 2 GPUs!
Streaming predictions to submission.tsv...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3505/3505 [01:28<00:00, 39.46it/s]


‚úÖ DONE! Created submission.tsv with 32,761,135 rows.
