In [5]:
import os
import gc
import math
import pickle
import random
import copy
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast, GradScaler
from tqdm import tqdm
from collections import deque

# =========================================================
# ⚙️ CONFIG (BEST PARAMS + NEW SETTINGS)
# =========================================================
CONFIG = {
    # --- Paths ---
    'TRAIN_IDS': '/root/CAFA6data/c95/train_ids_C95_split.npy',
    'VAL_IDS': '/root/CAFA6data/c95/val_ids_C95_split.npy',
    'TARGETS_PKL': '/root/CAFA6data/c95/train_targets_C95.pkl',
    'EMBED_DIR': '/root/CAFA6data/cafa6-embeds',
    'IA_FILE': '/root/CAFA6data/IA.tsv', 
    'VOCAB_FILE': '/root/CAFA6data/c95/vocab_C95_remove.csv', 
    
    # --- Training Hypers ---
    'input_dim': 1280,
    'batch_size': 256,      
    'device': "cuda",
    'epochs': 30,           
    'lr_max': 4e-4,         
    'seed': 42,            
    'patience': 5,          
    'ema_decay': 0.999,     
      
    'TAXON_PKL': '/root/cafa6/preprocessing/taxon_mapping_K_Species.pkl',
    
    # --- Best Params found ---
    'gamma_neg': 2.5,
    'gamma_pos': 0.0,
    'clip': 0.01,
    
    # --- Output ---
    'save_path': 'c95_mfo_best_ema.pth',
}

# =========================================================
# 1. UTILS: SEED & EMA & IA
# =========================================================
def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f"Seed set to {seed}")

class ModelEMA:
    def __init__(self, model, decay=0.999):
        self.module = copy.deepcopy(model).eval()

        self.module.load_state_dict(model.state_dict())

        self.decay = decay

        for p in self.module.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        msd = model.state_dict()
        for k, ema_v in self.module.state_dict().items():
            model_v = msd[k].detach()
            ema_v.copy_(ema_v * self.decay + (1.0 - self.decay) * model_v)

def load_ia_weights(vocab_path, ia_path):
    print("Loading IA Weights...")
    try:
        vocab_df = pd.read_csv(vocab_path)
        term_list = vocab_df['term'].tolist()
        term_to_idx = {t: i for i, t in enumerate(term_list)}
        num_classes = len(term_list)
        
        ia_weights = np.ones(num_classes, dtype=np.float32)
        if os.path.exists(ia_path):
            with open(ia_path, 'r') as f:
                for line in f:
                    parts = line.strip().split('\t')
                    if len(parts) >= 2:
                        term, ia = parts[0], float(parts[1])
                        if term in term_to_idx:
                            ia_weights[term_to_idx[term]] = ia
            print(f"   Loaded IA for {num_classes} terms.")
        else:
            print(f"   IA file not found. Using weights=1.0")
        return ia_weights
    except Exception as e:
        print(f"   Error loading IA: {e}")
        return np.ones(1, dtype=np.float32)

class ProteinDataset(Dataset):
    def __init__(self, ids_file, targets_pkl, embed_dir, taxon_pkl):
        self.ids = np.load(ids_file)
        self.embeds = np.load(os.path.join(embed_dir, "train_embeds.npy"))
        
        with open(os.path.join(embed_dir, "train_ids.txt")) as f:
            all_ids = [x.strip() for x in f]
        self.id_to_idx = {pid: i for i, pid in enumerate(all_ids)}
        
        with open(targets_pkl, 'rb') as f:
            self.targets = pickle.load(f)
        
        max_idx = 0
        for v in self.targets.values():
            if v: max_idx = max(max_idx, max(v))
        self.num_classes = max_idx + 1
        
        with open(taxon_pkl, 'rb') as f:
            tax_data = pickle.load(f)

        self.prot_to_taxon = tax_data["prot_to_taxon_idx"]
        self.num_taxa = tax_data["num_taxa_classes"]
        self.default_tax = self.num_taxa - 1  

    def __len__(self): return len(self.ids)
    def __getitem__(self, i):
        pid = self.ids[i]
        
        # Embedding
        emb_idx = self.id_to_idx.get(pid)
        feat = torch.tensor(self.embeds[emb_idx], dtype=torch.float32) if emb_idx is not None else torch.zeros(1280)
        
        target = torch.zeros(self.num_classes, dtype=torch.float32)
        if pid in self.targets:
            inds = self.targets[pid]
            if len(inds) > 0:
                inds = [int(x) for x in inds]
                target[inds] = 1.0
                
        tax_id = self.prot_to_taxon.get(pid, self.default_tax)
        tax_id = torch.tensor(tax_id, dtype=torch.long)
        
        return feat, tax_id, target

# =========================================================
# 2. MODEL & LOSS
# =========================================================
class AsymmetricLossOptimized(nn.Module):
    def __init__(self, gamma_neg=4, gamma_pos=0, clip=0.05, eps=1e-7):
        super().__init__()
        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.eps = eps
        
    def forward(self, x, y):
        x = x.float()
        y = y.float()
        
        x_sigmoid = torch.sigmoid(x)
        xs_pos = x_sigmoid
        xs_neg = 1 - x_sigmoid
        if self.clip > 0:
            xs_neg = (xs_neg + self.clip + self.eps).clamp(max=1.0)
        else:
            xs_neg = (xs_neg + self.eps).clamp(max=1.0)
        
        pt = y * xs_pos + (1 - y) * xs_neg
        log_pt = torch.log(pt.clamp(min=self.eps, max=1.0))
        
        pos_weight = (1 - xs_pos) ** self.gamma_pos
        neg_weight = (1 - xs_neg) ** self.gamma_neg
        
        weighted_loss = - (pos_weight * log_pt * y + neg_weight * log_pt * (1-y))
        return weighted_loss.sum() / x.size(0)

class WideProteinMLP_WithTaxon(nn.Module):
    def __init__(self, input_dim, num_classes, num_taxa, taxon_dim=64,
                 hidden_dims=[4096, 4096], dropout=0.4):
        super().__init__()

        # Normalize protein embedding
        self.seq_norm = nn.LayerNorm(input_dim)

        # Taxon branch
        self.taxon_embedding = nn.Embedding(num_taxa, taxon_dim)
        self.taxon_norm = nn.LayerNorm(taxon_dim)

        # UNK taxon
        self.unk_idx = num_taxa - 1
        with torch.no_grad():
            self.taxon_embedding.weight[self.unk_idx].zero_()

        self.register_buffer(
            "unk_vec", 
            self.taxon_embedding.weight[self.unk_idx].clone().detach()
        )
 
        combined_dim = input_dim + taxon_dim

        layers = []
        prev = combined_dim

        for h in hidden_dims:
            layers.append(nn.Linear(prev, h))
            layers.append(nn.GELU())
            layers.append(nn.Dropout(dropout))
            prev = h
        
        layers.append(nn.Linear(prev, num_classes))

        self.net = nn.Sequential(*layers)

    def forward(self, seq_emb, taxon_id):
        seq = self.seq_norm(seq_emb)
        tax = self.taxon_norm(self.taxon_embedding(taxon_id))
        x = torch.cat([seq, tax], dim=1)
        return self.net(x)

def gpu_calculate_fmax(preds_t, targets_t, ia_weights_np, device):
    if preds_t.numel() == 0: return 0.0
    ia_t = torch.from_numpy(ia_weights_np.astype(np.float32)).to(device)
    w = ia_t.unsqueeze(0)
    true_sum = (targets_t * w).sum(dim=1)
    valid_mask = true_sum > 0
    if valid_mask.sum().item() == 0: return 0.0
    
    p_sub = preds_t[valid_mask]
    t_sub = targets_t[valid_mask]
    w_sub = w
    w_true_sub = true_sum[valid_mask]
    
    best_f1 = 0.0
    thresholds = torch.linspace(0.0, 1.0, 51, device=device)
    for tau in thresholds:
        cut = (p_sub >= tau).float()
        tp = ((cut * t_sub) * w_sub).sum(dim=1)
        pred_sum = (cut * w_sub).sum(dim=1)
        prec = torch.where(pred_sum != 0, tp / pred_sum, torch.zeros_like(tp))
        rec = torch.where(w_true_sub != 0, tp / w_true_sub, torch.zeros_like(tp))
        avg_p = prec.mean(); avg_r = rec.mean()
        denom = (avg_p + avg_r)
        if denom > 0: f1 = (2.0 * avg_p * avg_r / denom).item()
        else: f1 = 0.0
        if f1 > best_f1: best_f1 = f1
    return float(best_f1)

def gpu_fmax_split(preds_t, targets_t, ia_t):
    """
    preds_t: Tensor [N, C]
    targets_t: Tensor [N, C]
    ia_t: Tensor [C]
    """
    device = preds_t.device
    w = ia_t.unsqueeze(0)                # [1, C]
    true_sum = (targets_t * w).sum(1)    # [N]

    valid = true_sum > 0
    if valid.sum() == 0:
        return 0.0

    p = preds_t[valid]     # [Nv, C]
    t = targets_t[valid]
    ts = true_sum[valid]

    best = 0.0
    thresholds = torch.linspace(0, 1, 101, device=device)

    for tau in thresholds:
        cut = (p >= tau).float()

        tp = (cut * t * w).sum(1)
        pred_sum = (cut * w).sum(1)

        prec = torch.where(pred_sum > 0, tp / pred_sum, torch.zeros_like(tp))
        rec = torch.where(ts > 0, tp / ts, torch.zeros_like(tp))

        avg_p = prec.mean()
        avg_r = rec.mean()
        denom = avg_p + avg_r

        f1 = torch.where(denom > 0, 2 * avg_p * avg_r / denom, torch.tensor(0., device=device))
        best = max(best, f1.item())

    return best

# =========================================================
# 3. TRAINING LOOP
# =========================================================
def train_epoch(model, ema_model, loader, loss_fn, optimizer, scaler, scheduler, device):
    model.train()
    running_loss = 0.0
    pbar = tqdm(loader, desc="Training")
    
    for feats, tax_id, labels in pbar:
        feats = feats.to(device)
        tax_id = tax_id.to(device)     
        labels = labels.to(device)
        
        optimizer.zero_grad()
        
        with autocast(device_type="cuda", dtype=torch.bfloat16):
            logits = model(feats, tax_id)
            loss = loss_fn(logits, labels)
            
        if not torch.isfinite(loss):
            print("!!! LOSS NaN or Inf detected.")
            print("max_logit", logits.max().item(), "min_logit", logits.min().item())
            torch.save(model.state_dict(), "crash_before_nan.pth")
            raise RuntimeError("NaN in loss")
        
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        
        if ema_model is not None:
            ema_model.update(model)
        
        running_loss += float(loss.item())

        del feats, tax_id, labels, logits, loss
        
    return running_loss / len(loader)

@torch.no_grad()
def validate_split_go_gpu(model, loader, loss_fn, device, ia_weights, vocab_df):
    model.eval()

    preds_list = []
    targets_list = []
    total_loss = 0.0

    for feats, tax_id, labels in loader:
        feats = feats.to(device)
        tax_id = tax_id.to(device)
        labels = labels.to(device)

        logits = model(feats, tax_id)
        loss = loss_fn(logits, labels)
        total_loss += loss.item()

        preds_list.append(torch.sigmoid(logits))
        targets_list.append(labels)

    preds_t = torch.cat(preds_list, dim=0).to(device)
    targets_t = torch.cat(targets_list, dim=0).to(device)
    val_loss = total_loss / len(loader)

    # Convert IA to torch
    ia_t = torch.tensor(ia_weights, dtype=torch.float32, device=device)

    scores = {}

    for asp in ["MFO", "BPO", "CCO"]:
        cols = vocab_df.index[vocab_df["aspect"] == asp].tolist()
        if not cols:
            continue

        scores[asp] = gpu_fmax_split(
            preds_t[:, cols],
            targets_t[:, cols],
            ia_t[cols]
        )

    avg_fmax = np.mean(list(scores.values()))

    return val_loss, avg_fmax, scores

# =========================================================
# MAIN
# =========================================================
def main():
    # 1. Init Seed
    seed_everything(CONFIG['seed'])
    
    # 2. Load Data & IA
    ia_weights = load_ia_weights(CONFIG['VOCAB_FILE'], CONFIG['IA_FILE'])
    train_ds = ProteinDataset(CONFIG['TRAIN_IDS'], CONFIG['TARGETS_PKL'], CONFIG['EMBED_DIR'], CONFIG['TAXON_PKL'])
    val_ds    = ProteinDataset(CONFIG['VAL_IDS'], CONFIG['TARGETS_PKL'], CONFIG['EMBED_DIR'], CONFIG['TAXON_PKL'])
    
    if len(ia_weights) != train_ds.num_classes:
        print(f"Resizing IA: {len(ia_weights)} -> {train_ds.num_classes}")
        new_ia = np.ones(train_ds.num_classes, dtype=np.float32)
        min_len = min(len(ia_weights), train_ds.num_classes)
        new_ia[:min_len] = ia_weights[:min_len]
        ia_weights = new_ia

    CONFIG['output_classes'] = train_ds.num_classes
    
    train_loader = DataLoader(train_ds, batch_size=CONFIG['batch_size'], shuffle=True, 
                              num_workers=16, pin_memory=True, persistent_workers=False, prefetch_factor=2)
    val_loader    = DataLoader(val_ds, batch_size=CONFIG['batch_size'], shuffle=False, 
                               num_workers=16, pin_memory=True, persistent_workers=False, prefetch_factor=2)

    # 3. Init Model & EMA
    print(f"START OFFICIAL TRAINING V2 (C95 MFO | Seed: {CONFIG['seed']})")
    model = WideProteinMLP_WithTaxon(
        input_dim=CONFIG['input_dim'],
        num_classes=CONFIG['output_classes'],
        num_taxa=train_ds.num_taxa,
        taxon_dim=64,
        hidden_dims=[4096, 4096],    
        dropout=0.4
    ).to(CONFIG['device'])
    
    ema_model = ModelEMA(model, decay=CONFIG['ema_decay'])
    print(f"   Using EMA Decay: {CONFIG['ema_decay']}")
    
    loss_fn = AsymmetricLossOptimized(gamma_neg=CONFIG['gamma_neg'], gamma_pos=CONFIG['gamma_pos'], clip=CONFIG['clip'])
    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG['lr_max'], weight_decay=0.02)
    scaler = GradScaler()
    
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=CONFIG['lr_max'], epochs=CONFIG['epochs'], 
        steps_per_epoch=len(train_loader), pct_start=0.3, div_factor=25, final_div_factor=1e4
    )

    # 4. Training Loop 
    best_proxy_fmax = 0.0
    patience_counter = 0

    fmax_buffer = deque(maxlen=3)   # window = 3 
    
    vocab_df = pd.read_csv(CONFIG['VOCAB_FILE'])
    
    for epoch in range(CONFIG['epochs']):
        tr_loss = train_epoch(model, ema_model, train_loader, loss_fn, optimizer, scaler, scheduler, CONFIG['device'])
        
        # Validate EMA Model (Usually better)
        val_loss, val_fmax, go_scores = validate_split_go_gpu(
            ema_model.module,
            val_loader,
            loss_fn,
            CONFIG["device"],
            ia_weights,
            vocab_df
        )
        
        print(f"Epoch {epoch+1}/{CONFIG['epochs']} | Train Loss: {tr_loss:.4f} | Val Loss: {val_loss:.4f} | EMA F-max: {val_fmax:.4f}")
        print("MF:", go_scores["MFO"], "BP:", go_scores["BPO"], "CC:", go_scores["CCO"])
         #  cập nhật buffer
        fmax_buffer.append(val_fmax)

        if len(fmax_buffer) == fmax_buffer.maxlen:
            proxy_fmax = float(np.mean(fmax_buffer))  
        else:
            proxy_fmax = val_fmax
            
        if proxy_fmax > best_proxy_fmax:
            best_proxy_fmax = proxy_fmax
            patience_counter = 0
            torch.save(ema_model.module.state_dict(), CONFIG['save_path'])
            print(f"    Saved Best EMA Model (proxy F-max: {best_proxy_fmax:.4f})")
        else:
            patience_counter += 1
            print(f"    Patience: {patience_counter}/{CONFIG['patience']}")
            if patience_counter >= CONFIG['patience']:
                print(" Early Stopping Triggered!")
                break
            
    print(f"\nFinished Training. Best Model saved at: {CONFIG['save_path']}")

if __name__ == "__main__":
    main()

Seed set to 42
Loading IA Weights...
   Loaded IA for 6413 terms.


START OFFICIAL TRAINING V2 (C95 MFO | Seed: 42)
   Using EMA Decay: 0.999


Training: 100%|██████████| 258/258 [00:09<00:00, 26.95it/s]


Epoch 1/30 | Train Loss: 136.6985 | Val Loss: 587.4654 | EMA F-max: 0.0612
MF: 0.026562336832284927 BP: 0.0501042902469635 CC: 0.10678490251302719
    Saved Best EMA Model (proxy F-max: 0.0612)


Training: 100%|██████████| 258/258 [00:07<00:00, 33.64it/s]


Epoch 2/30 | Train Loss: 62.9329 | Val Loss: 336.8705 | EMA F-max: 0.3520
MF: 0.4032348096370697 BP: 0.24300503730773926 CC: 0.4097551703453064
    Saved Best EMA Model (proxy F-max: 0.3520)


Training: 100%|██████████| 258/258 [00:07<00:00, 34.17it/s]


Epoch 3/30 | Train Loss: 59.0222 | Val Loss: 173.2320 | EMA F-max: 0.3922
MF: 0.42817485332489014 BP: 0.2857500910758972 CC: 0.46279221773147583
    Patience: 1/5


Training: 100%|██████████| 258/258 [00:07<00:00, 34.43it/s]


Epoch 4/30 | Train Loss: 56.8788 | Val Loss: 100.1480 | EMA F-max: 0.4181
MF: 0.4554789364337921 BP: 0.30779826641082764 CC: 0.4909173846244812
    Saved Best EMA Model (proxy F-max: 0.3874)


Training: 100%|██████████| 258/258 [00:07<00:00, 33.92it/s]


Epoch 5/30 | Train Loss: 55.4363 | Val Loss: 72.7233 | EMA F-max: 0.4441
MF: 0.48857492208480835 BP: 0.3270297944545746 CC: 0.5166446566581726
    Saved Best EMA Model (proxy F-max: 0.4181)


Training: 100%|██████████| 258/258 [00:07<00:00, 35.37it/s]


Epoch 6/30 | Train Loss: 54.2737 | Val Loss: 62.8351 | EMA F-max: 0.4651
MF: 0.5134793519973755 BP: 0.34440097212791443 CC: 0.5374487042427063
    Saved Best EMA Model (proxy F-max: 0.4424)


Training: 100%|██████████| 258/258 [00:07<00:00, 32.40it/s]


Epoch 7/30 | Train Loss: 53.3485 | Val Loss: 58.5686 | EMA F-max: 0.4822
MF: 0.5347944498062134 BP: 0.3583384156227112 CC: 0.5534701943397522
    Saved Best EMA Model (proxy F-max: 0.4638)


Training: 100%|██████████| 258/258 [00:07<00:00, 34.95it/s]


Epoch 8/30 | Train Loss: 52.5392 | Val Loss: 56.2894 | EMA F-max: 0.4961
MF: 0.5534444451332092 BP: 0.370866060256958 CC: 0.5639670491218567
    Saved Best EMA Model (proxy F-max: 0.4811)


Training: 100%|██████████| 258/258 [00:07<00:00, 34.73it/s]


Epoch 9/30 | Train Loss: 51.6688 | Val Loss: 54.8128 | EMA F-max: 0.5065
MF: 0.5669605731964111 BP: 0.3796241283416748 CC: 0.5730208158493042
    Saved Best EMA Model (proxy F-max: 0.4949)


Training: 100%|██████████| 258/258 [00:07<00:00, 34.52it/s]


Epoch 10/30 | Train Loss: 50.6988 | Val Loss: 53.7375 | EMA F-max: 0.5145
MF: 0.576317548751831 BP: 0.38816016912460327 CC: 0.5790892839431763
    Saved Best EMA Model (proxy F-max: 0.5057)


Training: 100%|██████████| 258/258 [00:07<00:00, 34.91it/s]


Epoch 11/30 | Train Loss: 49.7653 | Val Loss: 52.9718 | EMA F-max: 0.5221
MF: 0.5867264270782471 BP: 0.3952581584453583 CC: 0.5843107104301453
    Saved Best EMA Model (proxy F-max: 0.5144)


Training: 100%|██████████| 258/258 [00:07<00:00, 35.26it/s]


Epoch 12/30 | Train Loss: 48.7291 | Val Loss: 52.4399 | EMA F-max: 0.5285
MF: 0.5959430932998657 BP: 0.400628924369812 CC: 0.5888416171073914
    Saved Best EMA Model (proxy F-max: 0.5217)


Training: 100%|██████████| 258/258 [00:07<00:00, 34.46it/s]


Epoch 13/30 | Train Loss: 47.7368 | Val Loss: 52.1093 | EMA F-max: 0.5336
MF: 0.6030508875846863 BP: 0.4057435393333435 CC: 0.591980516910553
    Saved Best EMA Model (proxy F-max: 0.5281)


Training: 100%|██████████| 258/258 [00:07<00:00, 34.64it/s]


Epoch 14/30 | Train Loss: 46.6006 | Val Loss: 51.9380 | EMA F-max: 0.5379
MF: 0.6084147691726685 BP: 0.4096858501434326 CC: 0.5955181121826172
    Saved Best EMA Model (proxy F-max: 0.5333)


Training: 100%|██████████| 258/258 [00:07<00:00, 34.88it/s]


Epoch 15/30 | Train Loss: 45.4965 | Val Loss: 51.8894 | EMA F-max: 0.5409
MF: 0.6118411421775818 BP: 0.41235682368278503 CC: 0.5983776450157166
    Saved Best EMA Model (proxy F-max: 0.5374)


Training: 100%|██████████| 258/258 [00:07<00:00, 35.37it/s]


Epoch 16/30 | Train Loss: 44.3022 | Val Loss: 51.9515 | EMA F-max: 0.5434
MF: 0.6151878833770752 BP: 0.4148193597793579 CC: 0.600284993648529
    Saved Best EMA Model (proxy F-max: 0.5407)


Training: 100%|██████████| 258/258 [00:07<00:00, 34.67it/s]


Epoch 17/30 | Train Loss: 43.1928 | Val Loss: 52.1022 | EMA F-max: 0.5453
MF: 0.6175118684768677 BP: 0.4167115092277527 CC: 0.6015477776527405
    Saved Best EMA Model (proxy F-max: 0.5432)


Training: 100%|██████████| 258/258 [00:07<00:00, 35.20it/s]


Epoch 18/30 | Train Loss: 42.0196 | Val Loss: 52.3326 | EMA F-max: 0.5463
MF: 0.6192262172698975 BP: 0.4176892936229706 CC: 0.6018440127372742
    Saved Best EMA Model (proxy F-max: 0.5450)


Training: 100%|██████████| 258/258 [00:07<00:00, 34.76it/s]


Epoch 19/30 | Train Loss: 40.8864 | Val Loss: 52.6116 | EMA F-max: 0.5472
MF: 0.6204292178153992 BP: 0.41845881938934326 CC: 0.6026325225830078
    Saved Best EMA Model (proxy F-max: 0.5462)


Training: 100%|██████████| 258/258 [00:07<00:00, 34.63it/s]


Epoch 20/30 | Train Loss: 39.7551 | Val Loss: 52.9515 | EMA F-max: 0.5479
MF: 0.6210730671882629 BP: 0.4187500774860382 CC: 0.60383540391922
    Saved Best EMA Model (proxy F-max: 0.5471)


Training: 100%|██████████| 258/258 [00:07<00:00, 34.29it/s]


Epoch 21/30 | Train Loss: 38.7544 | Val Loss: 53.3294 | EMA F-max: 0.5480
MF: 0.6214264035224915 BP: 0.41843876242637634 CC: 0.6041554808616638
    Saved Best EMA Model (proxy F-max: 0.5477)


Training: 100%|██████████| 258/258 [00:07<00:00, 34.63it/s]


Epoch 22/30 | Train Loss: 37.8160 | Val Loss: 53.7394 | EMA F-max: 0.5476
MF: 0.6209926009178162 BP: 0.418040007352829 CC: 0.6037581562995911
    Saved Best EMA Model (proxy F-max: 0.5478)


Training: 100%|██████████| 258/258 [00:07<00:00, 34.18it/s]


Epoch 23/30 | Train Loss: 36.9200 | Val Loss: 54.1825 | EMA F-max: 0.5469
MF: 0.6197471022605896 BP: 0.4172874689102173 CC: 0.6037395596504211
    Patience: 1/5


Training: 100%|██████████| 258/258 [00:07<00:00, 34.45it/s]


Epoch 24/30 | Train Loss: 36.2521 | Val Loss: 54.6315 | EMA F-max: 0.5458
MF: 0.6183481216430664 BP: 0.41663169860839844 CC: 0.602554202079773
    Patience: 2/5


Training: 100%|██████████| 258/258 [00:07<00:00, 35.24it/s]


Epoch 25/30 | Train Loss: 35.6059 | Val Loss: 55.0818 | EMA F-max: 0.5454
MF: 0.6181322932243347 BP: 0.41574472188949585 CC: 0.6023028492927551
    Patience: 3/5


Training: 100%|██████████| 258/258 [00:07<00:00, 34.02it/s]


Epoch 26/30 | Train Loss: 35.1198 | Val Loss: 55.5036 | EMA F-max: 0.5444
MF: 0.6168670654296875 BP: 0.4144900143146515 CC: 0.6018363833427429
    Patience: 4/5


Training: 100%|██████████| 258/258 [00:07<00:00, 35.30it/s]


Epoch 27/30 | Train Loss: 34.7647 | Val Loss: 55.8992 | EMA F-max: 0.5437
MF: 0.6161227226257324 BP: 0.41374918818473816 CC: 0.6012502908706665
    Patience: 5/5
 Early Stopping Triggered!

Finished Training. Best Model saved at: c95_mfo_best_ema.pth


In [7]:


# =========================================================
#  CONFIG (CHẾ ĐỘ FULL TRAIN)
# =========================================================
CONFIG = {
    # --- Paths ---
    'TRAIN_IDS': '/root/CAFA6data/c95/train_ids_C95_split.npy',
    'VAL_IDS': '/root/CAFA6data/c95/val_ids_C95_split.npy',
    'TARGETS_PKL': '/root/CAFA6data/c95/train_targets_C95.pkl',
    'EMBED_DIR': '/root/CAFA6data/cafa6-embeds',
    'IA_FILE': '/root/CAFA6data/IA.tsv', 
    'VOCAB_FILE': '/root/CAFA6data/c95/vocab_C95_remove.csv', 
    'TAXON_PKL': '/root/cafa6/preprocessing/taxon_mapping_K_Species.pkl',
    
    # --- Training Hypers ---
    'input_dim': 1280,
    'batch_size': 256,      
    'device': "cuda",
          
    'lr_max': 4e-4,        
    'seed': 42,            
    'ema_decay': 0.999,   
    
    'epochs': 30,
    'stop_epoch': 22,
    
    # --- Params  ---
    'gamma_neg': 2.5,
    'gamma_pos': 0.0,
    'clip': 0.01,
    
}

# =========================================================
# 1. UTILS
# =========================================================
def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f" Seed set to {seed}")

class ModelEMA:
    def __init__(self, model, decay=0.999):
        self.module = copy.deepcopy(model).eval()
        self.module.load_state_dict(model.state_dict())
        self.decay = decay
        for p in self.module.parameters(): p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        msd = model.state_dict()
        for k, ema_v in self.module.state_dict().items():
            model_v = msd[k].detach()
            ema_v.copy_(ema_v * self.decay + (1.0 - self.decay) * model_v)

def load_ia_weights(vocab_path, ia_path):
    print(" Loading IA Weights...")
    try:
        vocab_df = pd.read_csv(vocab_path)
        term_list = vocab_df['term'].tolist()
        term_to_idx = {t: i for i, t in enumerate(term_list)}
        num_classes = len(term_list)
        ia_weights = np.ones(num_classes, dtype=np.float32)
        if os.path.exists(ia_path):
            with open(ia_path, 'r') as f:
                for line in f:
                    parts = line.strip().split('\t')
                    if len(parts) >= 2:
                        term, ia = parts[0], float(parts[1])
                        if term in term_to_idx:
                            ia_weights[term_to_idx[term]] = ia
        return ia_weights
    except Exception:
        return np.ones(1, dtype=np.float32)

class ProteinDataset(Dataset):
    def __init__(self, ids_source, targets_pkl, embed_dir, taxon_pkl):
        
        if isinstance(ids_source, str):
            self.ids = np.load(ids_source)
        else:
            self.ids = ids_source 
            
        # Dùng mmap_mode='r' để tiết kiệm RAM
        self.embeds = np.load(os.path.join(embed_dir, "train_embeds.npy"), mmap_mode='r')
        
        with open(os.path.join(embed_dir, "train_ids.txt")) as f:
            all_ids = [x.strip() for x in f]
        self.id_to_idx = {pid: i for i, pid in enumerate(all_ids)}
        
        with open(targets_pkl, 'rb') as f:
            self.targets = pickle.load(f)
        
        max_idx = 0
        for v in self.targets.values():
            if v: max_idx = max(max_idx, max(v))
        self.num_classes = max_idx + 1
        
        with open(taxon_pkl, 'rb') as f:
            tax_data = pickle.load(f)
        self.prot_to_taxon = tax_data["prot_to_taxon_idx"]
        self.num_taxa = tax_data["num_taxa_classes"]
        self.default_tax = self.num_taxa - 1

    def __len__(self): return len(self.ids)
    def __getitem__(self, i):
        pid = self.ids[i]
        emb_idx = self.id_to_idx.get(pid)
        if emb_idx is not None:
             # Copy vào RAM
            feat = torch.tensor(self.embeds[emb_idx], dtype=torch.float32)
        else:
            feat = torch.zeros(1280)
        
        target = torch.zeros(self.num_classes, dtype=torch.float32)
        if pid in self.targets:
            inds = self.targets[pid]
            if len(inds) > 0:
                target[inds] = 1.0
                
        tax_id = self.prot_to_taxon.get(pid, self.default_tax)
        return feat, torch.tensor(tax_id, dtype=torch.long), target

# =========================================================
# 2. MODEL & LOSS
# =========================================================
class AsymmetricLossOptimized(nn.Module):
    def __init__(self, gamma_neg=4, gamma_pos=0, clip=0.05, eps=1e-7):
        super().__init__()
        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.eps = eps
    def forward(self, x, y):
        x = x.float(); y = y.float()
        x_sigmoid = torch.sigmoid(x)
        xs_pos = x_sigmoid
        xs_neg = (1 - x_sigmoid + self.clip + self.eps).clamp(max=1.0)
        pt = y * xs_pos + (1 - y) * xs_neg
        log_pt = torch.log(pt.clamp(min=self.eps, max=1.0))
        weight = y*(1-xs_pos)**self.gamma_pos + (1-y)*(1-xs_neg)**self.gamma_neg
        loss = - weight * log_pt
        return loss.sum() / x.size(0)

class WideProteinMLP_WithTaxon(nn.Module):
    def __init__(self, input_dim, num_classes, num_taxa, taxon_dim=64, hidden_dims=[4096, 4096], dropout=0.4):
        super().__init__()
        self.seq_norm = nn.LayerNorm(input_dim)
        self.taxon_embedding = nn.Embedding(num_taxa, taxon_dim)
        self.taxon_norm = nn.LayerNorm(taxon_dim)
        with torch.no_grad(): self.taxon_embedding.weight[num_taxa-1].zero_()
        layers = []
        prev = input_dim + taxon_dim
        for h in hidden_dims:
            layers.append(nn.Linear(prev, h))
            layers.append(nn.GELU())
            layers.append(nn.Dropout(dropout))
            prev = h
        layers.append(nn.Linear(prev, num_classes))
        self.net = nn.Sequential(*layers)
    def forward(self, seq, tax):
        x = torch.cat([self.seq_norm(seq), self.taxon_norm(self.taxon_embedding(tax))], dim=1)
        return self.net(x)


def train_epoch(model, ema_model, loader, loss_fn, optimizer, scaler, scheduler, device):
    model.train()
    running_loss = 0.0
    pbar = tqdm(loader, desc="Training Full")
    
    for feats, tax_id, labels in pbar:
        feats = feats.to(device)
        tax_id = tax_id.to(device)     
        labels = labels.to(device)
        
        optimizer.zero_grad()
        with autocast(device_type="cuda", dtype=torch.bfloat16):
            logits = model(feats, tax_id)
            loss = loss_fn(logits, labels)
        
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        scaler.step(optimizer); scaler.update(); scheduler.step()
        
        if ema_model: ema_model.update(model)
        
        running_loss += float(loss.item())
        del feats, tax_id, labels, logits, loss
    
    return running_loss / len(loader)

# =========================================================
# MAIN
# =========================================================
def main():
    seed_everything(CONFIG['seed'])

    print("Merging Train + Val data for FULL TRAINING...")
    ids_train = np.load(CONFIG['TRAIN_IDS'])
    ids_val   = np.load(CONFIG['VAL_IDS'])
    ids_full  = np.concatenate([ids_train, ids_val])

    full_ds = ProteinDataset(ids_full, CONFIG['TARGETS_PKL'], CONFIG['EMBED_DIR'], CONFIG['TAXON_PKL'])
    full_loader = DataLoader(full_ds, batch_size=CONFIG['batch_size'], shuffle=True,
                            num_workers=16, pin_memory=True, persistent_workers=False)

    # Load IA Weights cho metric
    ia_weights = load_ia_weights(CONFIG['VOCAB_FILE'], CONFIG['IA_FILE'])
    if len(ia_weights) != full_ds.num_classes:
        new_ia = np.ones(full_ds.num_classes, dtype=np.float32)
        min_len = min(len(ia_weights), full_ds.num_classes)
        new_ia[:min_len] = ia_weights[:min_len]
        ia_weights = new_ia

    model = WideProteinMLP_WithTaxon(
        input_dim=CONFIG['input_dim'],
        num_classes=full_ds.num_classes,
        num_taxa=full_ds.num_taxa,
        taxon_dim=64,
        hidden_dims=[4096, 4096],
        dropout=0.4
    ).to(CONFIG['device'])

    ema_model = ModelEMA(model, decay=CONFIG['ema_decay'])
    
    loss_fn = AsymmetricLossOptimized(
        gamma_neg=CONFIG['gamma_neg'],
        gamma_pos=CONFIG['gamma_pos'],
        clip=CONFIG['clip']
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG['lr_max'], weight_decay=0.02)
    scaler = GradScaler()

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=CONFIG['lr_max'], 
        epochs=CONFIG['epochs'],                
        steps_per_epoch=len(full_loader),
        pct_start=0.3,
        div_factor=25,
        final_div_factor=1e4
    )

    print(" FINAL TRAINING START (full data, 22 epochs)")

    for epoch in range(CONFIG['epochs']):
        tr_loss = train_epoch(
            model, ema_model, full_loader,
            loss_fn, optimizer, scaler,
            scheduler, CONFIG['device']
        )
        
        current_lr = scheduler.get_last_lr()[0]
        print(
            f"Epoch {epoch+1}/{CONFIG['epochs']} | "
            f"Loss: {tr_loss:.4f} | "
            f"LR: {current_lr:.2e}"
        )

        # ===============================
        #  FORCE STOP AT BEST EPOCH
        # ===============================
        if epoch + 1 >= CONFIG['stop_epoch']:
            print(
                f" Forced stop at epoch {epoch+1} "
                f"(observed F-max peak)"
            )
            break

    # Save EMA as final model
    torch.save(ema_model.module.state_dict(), "final_cafa6_model_c95.pth")
    
if __name__ == "__main__":
    main()

 Seed set to 42
Merging Train + Val data for FULL TRAINING...


 Loading IA Weights...
 FINAL TRAINING START (full data, 25 epochs)


Training Full: 100%|██████████| 322/322 [00:09<00:00, 33.44it/s]


Epoch 1/30 | Loss: 122.6740 | LR: 2.76e-05


Training Full: 100%|██████████| 322/322 [00:10<00:00, 32.15it/s]


Epoch 2/30 | Loss: 61.7423 | LR: 6.09e-05


Training Full: 100%|██████████| 322/322 [00:10<00:00, 29.90it/s]


Epoch 3/30 | Loss: 58.1034 | LR: 1.12e-04


Training Full: 100%|██████████| 322/322 [00:08<00:00, 36.76it/s]


Epoch 4/30 | Loss: 56.1127 | LR: 1.75e-04


Training Full: 100%|██████████| 322/322 [00:09<00:00, 35.77it/s]


Epoch 5/30 | Loss: 54.7940 | LR: 2.41e-04


Training Full: 100%|██████████| 322/322 [00:08<00:00, 36.62it/s]


Epoch 6/30 | Loss: 53.7238 | LR: 3.04e-04


Training Full: 100%|██████████| 322/322 [00:09<00:00, 35.00it/s]


Epoch 7/30 | Loss: 52.8780 | LR: 3.55e-04


Training Full: 100%|██████████| 322/322 [00:08<00:00, 36.21it/s]


Epoch 8/30 | Loss: 52.0805 | LR: 3.88e-04


Training Full: 100%|██████████| 322/322 [00:08<00:00, 37.05it/s]


Epoch 9/30 | Loss: 51.2514 | LR: 4.00e-04


Training Full: 100%|██████████| 322/322 [00:09<00:00, 35.37it/s]


Epoch 10/30 | Loss: 50.3157 | LR: 3.98e-04


Training Full: 100%|██████████| 322/322 [00:09<00:00, 35.01it/s]


Epoch 11/30 | Loss: 49.3508 | LR: 3.91e-04


Training Full: 100%|██████████| 322/322 [00:08<00:00, 37.69it/s]


Epoch 12/30 | Loss: 48.3827 | LR: 3.80e-04


Training Full: 100%|██████████| 322/322 [00:08<00:00, 36.81it/s]


Epoch 13/30 | Loss: 47.3112 | LR: 3.65e-04


Training Full: 100%|██████████| 322/322 [00:08<00:00, 36.53it/s]


Epoch 14/30 | Loss: 46.2920 | LR: 3.47e-04


Training Full: 100%|██████████| 322/322 [00:08<00:00, 36.17it/s]


Epoch 15/30 | Loss: 45.1066 | LR: 3.25e-04


Training Full: 100%|██████████| 322/322 [00:09<00:00, 35.64it/s]


Epoch 16/30 | Loss: 43.9822 | LR: 3.00e-04


Training Full: 100%|██████████| 322/322 [00:09<00:00, 35.66it/s]


Epoch 17/30 | Loss: 42.8208 | LR: 2.73e-04


Training Full: 100%|██████████| 322/322 [00:08<00:00, 37.09it/s]


Epoch 18/30 | Loss: 41.7068 | LR: 2.44e-04


Training Full: 100%|██████████| 322/322 [00:08<00:00, 36.32it/s]


Epoch 19/30 | Loss: 40.5587 | LR: 2.15e-04


Training Full: 100%|██████████| 322/322 [00:08<00:00, 36.46it/s]


Epoch 20/30 | Loss: 39.4975 | LR: 1.85e-04


Training Full: 100%|██████████| 322/322 [00:09<00:00, 34.68it/s]


Epoch 21/30 | Loss: 38.4900 | LR: 1.55e-04


Training Full: 100%|██████████| 322/322 [00:08<00:00, 35.79it/s]


Epoch 22/30 | Loss: 37.5596 | LR: 1.27e-04
 Forced stop at epoch 22 (observed F-max peak)
