In [None]:
import os, sys, math, time, pickle, gc

import numpy as np

import pandas as pd

from tqdm import tqdm

import torch

import torch.nn as nn

import torch.optim as optim

from torch.utils.data import Dataset, DataLoader

from torch.amp import autocast, GradScaler


# ============================================================================

# 1. CONFIG

# ============================================================================

CONFIG = dict(
    # --- Precomputed protein embeddings (ESM-2 vectors + matching id list) ---
    EMBED_DIR="/kaggle/input/cafa6-embeds",

    # --- Scratch output dir, then label/metadata inputs (dataset c95-cafa6) ---
    WORK_DIR="/kaggle/working",
    LABEL_DIR="/kaggle/input/c95-cafa6",
    VOCAB_FILE="vocab_C95_remove.csv",
    TARGET_FILE="train_targets_C95.pkl",
    TRAIN_IDS="train_ids_C95_split.npy",
    VAL_IDS="val_ids_C95_split.npy",

    # Per-term IA weights used by the validation F-max.
    IA_FILE="/kaggle/input/cafa-6-protein-function-prediction/IA.tsv",

    # Protein -> taxon index mapping (plus the number of taxon classes).
    TAXON_PKL="/kaggle/input/cafa6-embeds/taxon_mapping_K140.pkl",

    # --- Model architecture ---
    input_dim=1280,
    hidden_dims=[2048, 4096],
    dropout=0.3,
    taxon_embed_dim=64,

    # --- Optimisation ---
    batch_size=16,
    lr=2e-4,
    weight_decay=1e-4,
    epochs=25,
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# ============================================================================
# 2. NEW MODEL (WITH TAXONOMY)
# ============================================================================

class WideProteinMLP_WithTaxon(nn.Module):
    """MLP over a protein embedding concatenated with a learned taxon embedding.

    Args:
        input_dim: width of the per-protein sequence embedding ``x_seq``.
        num_classes: number of output logits.
        num_taxa: size of the taxon vocabulary. The LAST index
            (``num_taxa - 1``) is reserved for "unknown taxon" and is kept
            as a fixed all-zero embedding row.
        taxon_dim: width of the learned taxon embedding.
        hidden_dims: widths of the hidden Linear -> GELU -> Dropout stages.
            Any iterable of ints works; a tuple default avoids the shared
            mutable-default pitfall.
        dropout: dropout probability after every hidden stage.
    """

    def __init__(self, input_dim, num_classes, num_taxa, taxon_dim=64, hidden_dims=(2048, 4096), dropout=0.3):
        super().__init__()

        # 1. Protein branch: normalise the sequence embedding.
        #    (It is a LayerNorm despite the ``bn_`` name; the attribute name is
        #    kept so existing checkpoints still load.)
        self.bn_input = nn.LayerNorm(input_dim)

        # 2. Taxonomy branch: learned embedding followed by LayerNorm.
        self.unk_idx = num_taxa - 1
        # padding_idx initialises the UNK row to zeros AND blocks its gradient,
        # so the row stays fixed during training by construction instead of
        # depending on the training loop to restore it after each step.
        self.taxon_embedding = nn.Embedding(num_taxa, taxon_dim, padding_idx=self.unk_idx)
        self.taxon_norm = nn.LayerNorm(taxon_dim)

        with torch.no_grad():
            self.taxon_embedding.weight[self.unk_idx].zero_()

        # Reference copy of the fixed UNK row. detach() so the buffer does not
        # carry autograd history from the Parameter it was cloned from.
        self.register_buffer(
            "unk_fixed_vector",
            self.taxon_embedding.weight[self.unk_idx].detach().clone(),
        )

        # 3. Fusion head over [sequence features | taxon features].
        combined_dim = input_dim + taxon_dim

        layers = []
        prev = combined_dim
        for h in hidden_dims:
            layers.append(nn.Linear(prev, h))
            layers.append(nn.GELU())
            layers.append(nn.Dropout(dropout))
            prev = h

        layers.append(nn.Linear(prev, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x_seq, x_tax):
        """Return raw logits.

        Args:
            x_seq: float tensor ``[batch, input_dim]``.
            x_tax: integer taxon ids ``[batch]``.

        Returns:
            Logits ``[batch, num_classes]``.
        """
        feat_seq = self.bn_input(x_seq)
        feat_tax = self.taxon_norm(self.taxon_embedding(x_tax))

        # [batch, input_dim + taxon_dim]
        combined = torch.cat([feat_seq, feat_tax], dim=1)

        return self.net(combined)

# ============================================================================
# 3. UPDATED DATASET (ALSO LOADS TAXON IDS)
# ============================================================================
class CAFA6Dataset(Dataset):
    """Map-style dataset yielding ``(embedding, taxon_idx, multi_hot_target)``.

    Args:
        ids_file: ``.npy`` file of protein ids for this split; looked up in
            ``CONFIG['LABEL_DIR']`` first, then ``CONFIG['WORK_DIR']``.
        targets_file: pickle mapping protein id -> positive class indices
            (used as ``target[indices] = 1``); same directory fallback.
        embed_dir: directory holding ``train_ids.txt`` (one id per line, in
            row order) and ``train_embeds.npy`` (the embedding matrix,
            memory-mapped read-only).
        num_classes: length of the multi-hot target vector.
        taxon_pkl: pickle with keys ``'prot_to_taxon_idx'`` and
            ``'num_taxa_classes'``; falls back to the same file name under
            ``CONFIG['WORK_DIR']`` when the path does not exist.
    """

    def __init__(self, ids_file, targets_file, embed_dir, num_classes, taxon_pkl):
        ids_path = os.path.join(CONFIG['LABEL_DIR'], ids_file)
        if not os.path.exists(ids_path):
            ids_path = os.path.join(CONFIG['WORK_DIR'], ids_file)
        self.ids = np.load(ids_path)

        targets_path = os.path.join(CONFIG['LABEL_DIR'], targets_file)
        if not os.path.exists(targets_path):
            targets_path = os.path.join(CONFIG['WORK_DIR'], targets_file)
        # NOTE(review): pickle.load can execute arbitrary code; trusted files only.
        with open(targets_path, 'rb') as fh:
            self.labels_dict = pickle.load(fh)

        self.num_classes = num_classes

        # Protein id -> row of the embedding matrix (later duplicates win).
        with open(os.path.join(embed_dir, "train_ids.txt"), 'r') as fh:
            self.id_to_embed_idx = {line.strip(): row for row, line in enumerate(fh)}
        self.embed_matrix = np.load(os.path.join(embed_dir, "train_embeds.npy"), mmap_mode='r')

        # Taxon mapping: given path, else the same file name inside WORK_DIR.
        tax_path = taxon_pkl
        if not os.path.exists(tax_path):
            tax_path = os.path.join(CONFIG['WORK_DIR'], os.path.basename(taxon_pkl))
            print("‚ö†Ô∏è USING TAXON PKL FROM WORK_DIR:", tax_path)

        with open(tax_path, 'rb') as fh:
            tax_data = pickle.load(fh)

        self.prot_to_taxon = tax_data['prot_to_taxon_idx']
        # The last taxon index is the "unknown" bucket for unmapped proteins.
        self.default_tax = tax_data['num_taxa_classes'] - 1

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        prot_id = self.ids[idx]

        # Embedding row, or a zero vector when the id has no embedding.
        row = self.id_to_embed_idx.get(prot_id)
        feat = (
            torch.zeros(CONFIG["input_dim"], dtype=torch.float32)
            if row is None
            else torch.from_numpy(self.embed_matrix[row].copy()).float()
        )

        # Multi-hot target; unlabeled proteins get an all-zero vector.
        target = torch.zeros(self.num_classes, dtype=torch.float)
        term_indices = self.labels_dict.get(prot_id, [])
        if len(term_indices) > 0:
            target[term_indices] = 1.0

        # Taxon index, defaulting to the UNK bucket.
        taxon_idx = self.prot_to_taxon.get(prot_id, self.default_tax)

        return feat, torch.tensor(taxon_idx, dtype=torch.long), target


# ============================================================================

# 4. LOSS: ASL OPTIMIZED 

# ============================================================================


class AsymmetricLossOptimized(nn.Module):
    """Asymmetric Loss (ASL) for multi-label classification, summed over all elements.

    Args:
        gamma_neg: focusing exponent for negative targets, applied to
            ``1 - xs_neg``.
        gamma_pos: focusing exponent for positive targets, applied to
            ``1 - sigmoid(x)``.
        clip: margin added to ``1 - sigmoid(x)`` for negatives (capped at 1);
            ``0`` disables it.
        eps: lower clamp on the probability before ``log``.
    """

    def __init__(self, gamma_neg=4, gamma_pos=0, clip=0.05, eps=1e-8):
        super().__init__()
        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.eps = eps

    def forward(self, x, y):
        """Return the summed ASL loss.

        Args:
            x: raw logits, same shape as ``y``.
            y: binary (0/1) multi-hot targets.

        Returns:
            0-dim float32 tensor (sum over all elements).
        """
        # Compute in float32. Under fp16 autocast, sigmoid of a strongly
        # negative logit underflows to 0, and eps=1e-8 itself rounds to 0 in
        # fp16. The clamp would then not protect log(), giving an inf loss.
        x = x.float()
        y = y.to(x.dtype)

        xs_pos = torch.sigmoid(x)
        xs_neg = 1 - xs_pos
        if self.clip > 0:
            xs_neg = (xs_neg + self.clip).clamp(max=1)

        # Probability of the true class for each element.
        pt = y * xs_pos + (1 - y) * xs_neg
        log_pt = torch.log(pt.clamp(min=self.eps))

        pos_weight = (1 - xs_pos) ** self.gamma_pos
        neg_weight = (1 - xs_neg) ** self.gamma_neg
        weighted_loss = -(pos_weight * log_pt * y + neg_weight * log_pt * (1 - y))
        return weighted_loss.sum()

# ============================================================================

# 5. METRIC & TRAINING LOOP

# ============================================================================


def calculate_fmax_subset(preds, targets, ia_weights):
    """IA-weighted F-max over thresholds 0.00, 0.02, ..., 1.00.

    For each threshold tau, a term is predicted when ``preds >= tau``.
    Weighted precision is averaged only over proteins with a non-zero
    weighted prediction sum at tau. Weighted recall is averaged over all
    proteins with a non-zero weighted annotation sum. This follows the CAFA
    evaluation convention; counting non-predicting proteins as precision 0
    would understate precision.

    Args:
        preds: ``[n_proteins, n_terms]`` scores in [0, 1].
        targets: ``[n_proteins, n_terms]`` 0/1 labels (any numeric dtype).
        ia_weights: ``[n_terms]`` per-term weights.

    Returns:
        Best F1 over thresholds as a Python float. Returns 0.0 when no
        protein has a weighted annotation.
    """
    preds = np.asarray(preds)
    # Float dtypes so the divisions below never hit integer-output casting errors.
    targets = np.asarray(targets, dtype=np.float64)
    w = np.asarray(ia_weights, dtype=np.float64).reshape(1, -1)

    true_sum = (targets * w).sum(axis=1)
    valid_mask = true_sum > 0
    if not valid_mask.any():
        return 0.0

    p_sub = preds[valid_mask]
    tw_sub = targets[valid_mask] * w        # threshold-independent, hoisted
    w_true_sub = true_sum[valid_mask]       # all > 0 by construction

    best_f1 = 0.0
    for tau in np.linspace(0.0, 1.0, 51):
        cut = p_sub >= tau
        tp = np.where(cut, tw_sub, 0.0).sum(axis=1)
        pred_sum = np.where(cut, w, 0.0).sum(axis=1)

        has_pred = pred_sum > 0
        if not has_pred.any():
            continue  # no protein predicts anything: precision undefined -> skip

        avg_p = np.mean(tp[has_pred] / pred_sum[has_pred])
        avg_r = np.mean(tp / w_true_sub)
        if avg_p + avg_r > 0:
            best_f1 = max(best_f1, 2 * avg_p * avg_r / (avg_p + avg_r))
    return float(best_f1)

def validate_detailed(model, loader, vocab_df, ia_weights, device):
    """Run ``model`` over ``loader`` and compute IA-weighted F-max per aspect.

    Args:
        model: module called as ``model(x_seq, x_tax)`` that returns logits.
        loader: iterable of ``(x_seq, x_tax, y)`` batches.
        vocab_df: DataFrame whose ``'aspect'`` column labels each output
            column (row order == logit column order).
        ia_weights: array of per-term weights aligned with ``vocab_df`` rows.
        device: torch device (string or ``torch.device``) for inference.

    Returns:
        ``(avg_fmax, scores)``. ``scores`` maps each of MFO/BPO/CCO present in
        ``vocab_df`` to its F-max. ``avg_fmax`` is their mean, or 0.0 when no
        aspect (or no batch) is available.
    """
    device_type = torch.device(device).type
    # Mixed precision only on CUDA; on CPU run plain float32.
    use_amp = device_type == "cuda"

    model.eval()
    all_preds, all_targets = [], []
    with torch.no_grad():
        for x_seq, x_tax, y in loader:
            x_seq = x_seq.to(device)
            x_tax = x_tax.to(device)
            with autocast(device_type=device_type, enabled=use_amp):
                logits = model(x_seq, x_tax)
            # Upcast so thresholds are compared against float32 probabilities.
            all_preds.append(torch.sigmoid(logits.float()).cpu().numpy())
            all_targets.append(y.numpy())

    if not all_preds:
        return 0.0, {}
    Y_p = np.vstack(all_preds)
    Y_t = np.vstack(all_targets)

    aspects = vocab_df['aspect'].to_numpy()
    scores = {}
    for aspect in ('MFO', 'BPO', 'CCO'):
        # Positional column indices, independent of vocab_df's index labels.
        cols = np.flatnonzero(aspects == aspect)
        if cols.size == 0:
            continue
        scores[aspect] = calculate_fmax_subset(Y_p[:, cols], Y_t[:, cols], ia_weights[cols])

    avg_fmax = float(np.mean(list(scores.values()))) if scores else 0.0
    return avg_fmax, scores


# ============================================================================
# 6. MAIN TRAINING LOOP
# ============================================================================
def _load_ia_weights(vocab_df, num_classes):
    """Return per-term IA weights aligned with ``vocab_df`` rows (missing terms -> 1.0).

    Best effort: if the IA file cannot be read or parsed, falls back to
    uniform weights and prints a warning, instead of silently swallowing
    every exception.
    """
    try:
        ia_df = pd.read_csv(CONFIG['IA_FILE'], sep='\t', names=['term', 'ia'], header=None)
        ia_map = dict(zip(ia_df.term, ia_df.ia))
        return np.array([ia_map.get(t, 1.0) for t in vocab_df.term.values])
    except (OSError, ValueError, KeyError, AttributeError) as err:
        print(f"   WARNING: could not load IA weights ({err!r}); using uniform weights.")
        return np.ones(num_classes)


def _resolve_taxon_pkl(path):
    """Return ``path`` if it exists, else the same file name under WORK_DIR.

    Uses the basename because ``os.path.join`` discards WORK_DIR when the
    second argument is absolute, which would make the fallback a no-op.
    This matches the fallback in CAFA6Dataset.
    """
    if os.path.exists(path):
        return path
    return os.path.join(CONFIG['WORK_DIR'], os.path.basename(path))


def train_c95_taxon():
    """Train WideProteinMLP_WithTaxon on the C95 split and keep the best checkpoint.

    The run:
    - loads the term vocabulary, IA weights and taxon mapping;
    - trains with ASL + AdamW + OneCycleLR (mixed precision on CUDA only);
    - evaluates IA-weighted F-max on the validation split after each epoch;
    - saves the best ``state_dict`` to ``best_model_c95_taxon.pth`` in the
      current working directory.
    """
    print("üöÄ START TRAINING C95 WITH TAXONOMY (SCRATCH)...")

    device = CONFIG['device']
    device_type = torch.device(device).type
    use_amp = device_type == "cuda"

    # 1. Load resources
    vocab_df = pd.read_csv(os.path.join(CONFIG['LABEL_DIR'], CONFIG['VOCAB_FILE']))
    num_classes = len(vocab_df)
    ia_weights = _load_ia_weights(vocab_df, num_classes)

    # 2. Dataset & model
    # NOTE(review): pickle.load can execute arbitrary code; trusted files only.
    with open(_resolve_taxon_pkl(CONFIG['TAXON_PKL']), 'rb') as f:
        tax_data = pickle.load(f)
    num_taxa = tax_data['num_taxa_classes']
    print(f"   Num Taxa Classes: {num_taxa}")

    train_ds = CAFA6Dataset(CONFIG['TRAIN_IDS'], CONFIG['TARGET_FILE'], CONFIG['EMBED_DIR'], num_classes, CONFIG['TAXON_PKL'])
    val_ds = CAFA6Dataset(CONFIG['VAL_IDS'], CONFIG['TARGET_FILE'], CONFIG['EMBED_DIR'], num_classes, CONFIG['TAXON_PKL'])
    train_loader = DataLoader(train_ds, batch_size=CONFIG['batch_size'], shuffle=True, num_workers=2)
    val_loader = DataLoader(val_ds, batch_size=CONFIG['batch_size'] * 2, shuffle=False, num_workers=2)

    model = WideProteinMLP_WithTaxon(
        input_dim=CONFIG['input_dim'],
        num_classes=num_classes,
        num_taxa=num_taxa,
        taxon_dim=CONFIG['taxon_embed_dim'],
        hidden_dims=CONFIG['hidden_dims'],
        dropout=CONFIG['dropout'],
    ).to(device)

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    # Unwrapped module: attribute access works with or without DataParallel.
    core = getattr(model, "module", model)

    criterion = AsymmetricLossOptimized(gamma_neg=2.5, gamma_pos=0, clip=0.05)
    optimizer = optim.AdamW(model.parameters(), lr=CONFIG['lr'], weight_decay=CONFIG['weight_decay'])
    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=CONFIG['lr'], steps_per_epoch=len(train_loader), epochs=CONFIG['epochs'])
    # Disabled (a no-op pass-through) when not running on CUDA.
    scaler = GradScaler("cuda", enabled=use_amp)

    best_score = 0.0

    for epoch in range(CONFIG['epochs']):
        model.train()
        loss_sum = 0.0
        pbar = tqdm(train_loader, desc=f"Ep {epoch+1}", leave=False)

        for x_seq, x_tax, y in pbar:
            x_seq = x_seq.to(device)
            x_tax = x_tax.to(device)
            y = y.to(device)

            optimizer.zero_grad()
            with autocast(device_type=device_type, enabled=use_amp):
                logits = model(x_seq, x_tax)
                loss = criterion(logits, y)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Restore the fixed UNK taxon row after the optimizer step so that
            # "unknown taxon" always maps to the same embedding vector.
            with torch.no_grad():
                core.taxon_embedding.weight[core.unk_idx].copy_(core.unk_fixed_vector)

            scheduler.step()

            # Single host sync per step.
            loss_val = loss.item()
            loss_sum += loss_val
            pbar.set_postfix({'loss': f"{loss_val:.4f}"})

        val_fmax, val_details = validate_detailed(model, val_loader, vocab_df, ia_weights, device)
        print(f"Epoch {epoch+1}: Loss={loss_sum/len(train_loader):.4f} | Val F-max={val_fmax:.4f} {val_details}")

        if val_fmax > best_score:
            best_score = val_fmax
            # NOTE(review): under DataParallel these keys carry a "module."
            # prefix; loaders must handle both layouts.
            torch.save(model.state_dict(), "best_model_c95_taxon.pth")
            print("   üèÜ Saved Best Model (With Taxon)!")

# Script entry point: start training only when this file is executed directly.
if __name__ == "__main__":
    train_c95_taxon()


Epoch 1: Loss=2234.6272 | Val F-max=0.4149 {'MFO': 0.4521210724417265, 'BPO': 0.3122263168422698, 'CCO': 0.48028865099381846}

üèÜ Saved Best Model (With Taxon)!


Epoch 2: Loss=1718.1856 | Val F-max=0.4646 {'MFO': 0.5139700758043899, 'BPO': 0.3504591698159337, 'CCO': 0.5294352775050785}

üèÜ Saved Best Model (With Taxon)!


Epoch 3: Loss=1629.6295 | Val F-max=0.4949 {'MFO': 0.5528053533263427, 'BPO': 0.3677360049264319, 'CCO': 0.5640186982491494}

üèÜ Saved Best Model (With Taxon)!


Epoch 4: Loss=1584.2367 | Val F-max=0.5090 {'MFO': 0.5695660534038522, 'BPO': 0.38302653231172723, 'CCO': 0.5744884560189405}

üèÜ Saved Best Model (With Taxon)!


Epoch 5: Loss=1555.7461 | Val F-max=0.5242 {'MFO': 0.5934636678977137, 'BPO': 0.3967266141178558, 'CCO': 0.5825291143062825}

üèÜ Saved Best Model (With Taxon)!


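#### Đánh giá là thất bại so với bản gốc của c95 (không dùng taxon) => có lẽ vì bản embeds ESM2 650M đã đủ mạnh với C95 -> đủ nhận biết các nhãn trong tập C95 -> việc thêm thông tin loài vào khiến mô hình bị loạn hơn.

(English: Verdict — a failure compared with the original C95 model without taxonomy. A likely reason is that the ESM2-650M embeddings are already strong enough on C95 to recognize the labels in that set, so adding species information only makes the model more confused.)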