In [None]:
import os
import gc
import math
import pickle
import random
import copy
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast, GradScaler
from tqdm import tqdm
from collections import deque

# =========================================================
# CONFIG (BEST PARAMS + NEW SETTINGS)
# =========================================================
# Settings for the train/validation run. main() reads the entries below and
# additionally stores 'output_classes' in this dict at runtime.
CONFIG = {
    # --- Paths ---
    'TRAIN_IDS': '/root/CAFA6data/c99/train_ids_C99_split.npy',  # .npy array of training protein IDs
    'VAL_IDS': '/root/CAFA6data/c99/val_ids_C99_split.npy',  # .npy array of validation protein IDs
    'TARGETS_PKL': '/root/CAFA6data/c99/train_targets_C99.pkl',  # pickled dict: protein ID -> label indices
    'EMBED_DIR': '/root/CAFA6data/cafa6-embeds',  # directory holding train_embeds.npy and train_ids.txt
    'IA_FILE': '/root/CAFA6data/IA.tsv',  # tab-separated "term<TAB>weight" lines
    'VOCAB_FILE': '/root/CAFA6data/c99/vocab_C99_remove.csv',  # CSV with 'term' and 'aspect' columns

    'TAXON_PKL': '/root/cafa6/preprocessing/taxon_mapping_K_Species.pkl',  # pickle with 'prot_to_taxon_idx' and 'num_taxa_classes'

    # --- Training Hypers ---
    'input_dim': 1280,  # passed to the model as input_dim
    'batch_size': 192,  # used for both the train and validation DataLoader
    'device': "cuda",  # device the model and batches are moved to
    'epochs': 20,  # loop upper bound and OneCycleLR schedule length
    'lr_max': 4e-4,  # AdamW lr and OneCycleLR max_lr
    'seed': 42,  # forwarded to seed_everything
    'patience': 6,  # non-improving epochs tolerated before early stopping
    'ema_decay': 0.999,  # decay handed to ModelEMA

    # --- Best Params found ---
    'gamma_neg': 4,  # AsymmetricLossOptimized gamma_neg
    'gamma_pos': 1.0,  # AsymmetricLossOptimized gamma_pos
    'clip': 0.03,  # AsymmetricLossOptimized clip

    # --- Output ---
    'save_path': 'c99_mfo_balanced_best_ema.pth',  # file the best EMA state_dict is written to
}

# =========================================================
# 1. UTILS: SEED & EMA & IA
# =========================================================
def seed_everything(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Integer seed applied to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Same seed for every random source the training code touches.
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

    # Trade cuDNN autotuning for repeatable kernel selection.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    print(f" Seed set to {seed}")

class ModelEMA:
    """Keep an exponential moving average (EMA) copy of a model.

    ``self.module`` is a frozen, eval-mode clone of the model passed to the
    constructor. Each call to :meth:`update` moves its floating-point state
    towards the live model: ``ema = decay * ema + (1 - decay) * model``.

    Args:
        model: The ``nn.Module`` to shadow.
        decay: EMA decay factor; values closer to 1 average over more steps.
    """

    def __init__(self, model, decay=0.999):
        # deepcopy clones the structure *and* the current weights, so the
        # shadow starts out identical to ``model``.
        self.module = copy.deepcopy(model).eval()

        # EMA decay
        self.decay = decay

        # The shadow is never optimised directly, so it needs no gradients.
        for p in self.module.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        """Blend the current state of ``model`` into the EMA copy in place.

        Args:
            model: Live model whose ``state_dict`` keys match the shadow's.
        """
        msd = model.state_dict()
        for k, ema_v in self.module.state_dict().items():
            model_v = msd[k].detach()
            if ema_v.dtype.is_floating_point:
                ema_v.copy_(ema_v * self.decay + (1.0 - self.decay) * model_v)
            else:
                # Integer/bool buffers (e.g. BatchNorm's num_batches_tracked)
                # are counters, not weights: averaging them in float and
                # truncating back would corrupt them, so mirror them verbatim.
                ema_v.copy_(model_v)

def load_ia_weights(vocab_path, ia_path):
    """Build a per-class weight vector from the vocabulary and the IA file.

    The vocabulary CSV must have a ``term`` column; row order defines the class
    index. The IA file is read as tab-separated ``term<TAB>weight`` lines.
    Terms missing from the IA file keep a weight of 1.0, and lines whose weight
    is not numeric (e.g. a header row) are skipped instead of aborting the
    whole load.

    Args:
        vocab_path: Path to the vocabulary CSV.
        ia_path: Path to the IA file.

    Returns:
        ``np.float32`` array with one weight per vocabulary row. If the
        vocabulary itself cannot be read, a length-1 array of ones is returned
        (the historical fallback; callers are expected to resize it).
    """
    print(" Loading IA Weights...")
    try:
        vocab_df = pd.read_csv(vocab_path)
        term_list = vocab_df['term'].tolist()
    except (OSError, KeyError, ValueError) as e:
        print(f"    Error loading IA: {e}")
        return np.ones(1, dtype=np.float32)

    term_to_idx = {t: i for i, t in enumerate(term_list)}
    num_classes = len(term_list)
    ia_weights = np.ones(num_classes, dtype=np.float32)

    if not os.path.exists(ia_path):
        print(f"    IA file not found. Using weights=1.0")
        return ia_weights

    matched = set()
    try:
        with open(ia_path, 'r') as f:
            for line in f:
                parts = line.strip().split('\t')
                if len(parts) < 2:
                    continue
                try:
                    ia = float(parts[1])
                except ValueError:
                    # Header or malformed row: ignore it, keep the rest.
                    continue
                idx = term_to_idx.get(parts[0])
                if idx is not None:
                    ia_weights[idx] = ia
                    matched.add(idx)
    except (OSError, ValueError) as e:
        # Unreadable / undecodable file: fall back to uniform weights.
        print(f"    Error loading IA: {e}")
        return np.ones(num_classes, dtype=np.float32)

    # Report how many vocabulary terms actually received a weight.
    print(f"    Loaded IA for {len(matched)}/{num_classes} terms.")
    return ia_weights

class ProteinDataset(Dataset):
    """Dataset yielding ``(embedding, taxon_index, multi_hot_target)`` per protein.

    Args:
        ids_file: ``.npy`` file with the protein IDs that make up this split.
        targets_pkl: Pickle of a dict mapping protein ID -> iterable of
            positive class indices.
        embed_dir: Directory containing ``train_embeds.npy`` (one row per
            protein) and ``train_ids.txt`` (one ID per line, same row order).
        taxon_pkl: Pickle of a dict with keys ``prot_to_taxon_idx`` and
            ``num_taxa_classes``.

    NOTE: ``pickle.load`` can execute arbitrary code; only point this class at
    trusted files.
    """

    def __init__(self, ids_file, targets_pkl, embed_dir, taxon_pkl):
        self.ids = np.load(ids_file)

        # Memory-map the embedding matrix so that several dataset instances
        # (train + val) do not each hold a full private copy in RAM.
        self.embeds = np.load(os.path.join(embed_dir, "train_embeds.npy"), mmap_mode='r')
        # Shape of a single embedding row; used for the missing-ID fallback
        # instead of a hard-coded width.
        self._feat_shape = tuple(self.embeds.shape[1:])

        with open(os.path.join(embed_dir, "train_ids.txt")) as f:
            all_ids = [x.strip() for x in f]
        self.id_to_idx = {pid: i for i, pid in enumerate(all_ids)}

        with open(targets_pkl, 'rb') as f:
            self.targets = pickle.load(f)

        # Number of classes = largest label index seen anywhere + 1.
        max_idx = 0
        for v in self.targets.values():
            # len() instead of truthiness so array-valued entries work too.
            if v is not None and len(v) > 0:
                max_idx = max(max_idx, int(max(v)))
        self.num_classes = max_idx + 1

        with open(taxon_pkl, 'rb') as f:
            tax_data = pickle.load(f)

        self.prot_to_taxon = tax_data["prot_to_taxon_idx"]
        self.num_taxa = tax_data["num_taxa_classes"]
        self.default_tax = self.num_taxa - 1  # UNK taxon

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, i):
        """Return ``(feat, tax_id, target)`` for the ``i``-th ID of the split.

        ``feat`` is the float32 embedding row (all zeros when the ID has no
        embedding), ``tax_id`` a scalar long tensor (the last taxon index when
        the ID is unmapped) and ``target`` a float32 multi-hot vector of
        length ``num_classes``.
        """
        pid = self.ids[i]

        # Embedding (torch.tensor copies the row out of the memory map).
        emb_idx = self.id_to_idx.get(pid)
        if emb_idx is not None:
            feat = torch.tensor(self.embeds[emb_idx], dtype=torch.float32)
        else:
            feat = torch.zeros(self._feat_shape, dtype=torch.float32)

        target = torch.zeros(self.num_classes, dtype=torch.float32)
        if pid in self.targets:
            inds = self.targets[pid]
            if len(inds) > 0:
                # Plain ints so NumPy scalar indices are accepted as well.
                inds = [int(x) for x in inds]
                target[inds] = 1.0

        tax_id = self.prot_to_taxon.get(pid, self.default_tax)
        tax_id = torch.tensor(tax_id, dtype=torch.long)

        return feat, tax_id, target

# =========================================================
# 2. MODEL & LOSS
# =========================================================
class AsymmetricLossOptimized(nn.Module):
    """Asymmetric focal-style loss on multi-label logits.

    Positives and negatives get separate focusing exponents, and the
    negative-class probability is shifted upwards by ``clip`` (then capped at
    1) before the log is taken.

    Args:
        gamma_neg: Focusing exponent applied to negative labels.
        gamma_pos: Focusing exponent applied to positive labels.
        clip: Shift added to the negative probability when it is > 0.
        eps: Small constant guarding the log and the clamp.
    """

    def __init__(self, gamma_neg=4, gamma_pos=0, clip=0.05, eps=1e-7):
        super().__init__()
        self.gamma_neg, self.gamma_pos = gamma_neg, gamma_pos
        self.clip, self.eps = clip, eps

    def forward(self, x, y):
        """Return the loss summed over all elements, divided by ``x.size(0)``.

        Args:
            x: Logits.
            y: Targets with the same shape as ``x``.
        """
        logits, labels = x.float(), y.float()

        prob_pos = torch.sigmoid(logits)
        prob_neg = 1 - prob_pos
        if self.clip > 0:
            prob_neg = prob_neg + self.clip
        prob_neg = (prob_neg + self.eps).clamp(max=1.0)

        # Probability assigned to the labelled outcome of each element.
        p_true = labels * prob_pos + (1 - labels) * prob_neg
        log_p = torch.log(p_true.clamp(min=self.eps, max=1.0))

        pos_term = (1 - prob_pos) ** self.gamma_pos * log_p * labels
        neg_term = (1 - prob_neg) ** self.gamma_neg * log_p * (1 - labels)

        per_element = -(pos_term + neg_term)
        return per_element.sum() / logits.size(0)

class WideProteinMLP_WithTaxon(nn.Module):
    """MLP classifier over a protein embedding concatenated with a taxon embedding.

    Both inputs are layer-normalised, concatenated along the feature axis and
    passed through ``Linear -> GELU -> Dropout`` stacks followed by a final
    linear layer that emits one logit per class.

    Args:
        input_dim: Width of the protein embedding.
        num_classes: Number of output logits.
        num_taxa: Size of the taxon vocabulary; the last index is treated as
            the unknown (UNK) taxon.
        taxon_dim: Width of the learned taxon embedding.
        hidden_dims: Sizes of the hidden layers (any iterable of ints).
        dropout: Dropout probability after every hidden activation.
    """

    # The default for hidden_dims is an immutable tuple so it cannot be shared
    # and mutated across instances; lists passed by callers still work.
    def __init__(self, input_dim, num_classes, num_taxa, taxon_dim=64,
                 hidden_dims=(4096, 4096), dropout=0.4):
        super().__init__()

        # Normalize protein embedding
        self.seq_norm = nn.LayerNorm(input_dim)

        # Taxon branch
        self.taxon_embedding = nn.Embedding(num_taxa, taxon_dim)
        nn.init.normal_(self.taxon_embedding.weight, mean=0.0, std=0.1)
        self.taxon_norm = nn.LayerNorm(taxon_dim)

        # UNK taxon: its row starts at zero. It is not a padding_idx, so it
        # still receives gradient updates during training.
        self.unk_idx = num_taxa - 1
        with torch.no_grad():
            self.taxon_embedding.weight[self.unk_idx].zero_()

        # Snapshot of the initial UNK row. forward() does not read it; it is
        # kept because it is part of the state_dict of saved checkpoints.
        self.register_buffer(
            "unk_vec",
            self.taxon_embedding.weight[self.unk_idx].clone().detach()
        )

        # Combined feature dimension
        combined_dim = input_dim + taxon_dim

        layers = []
        prev = combined_dim

        for h in hidden_dims:
            layers.append(nn.Linear(prev, h))
            layers.append(nn.GELU())
            layers.append(nn.Dropout(dropout))
            prev = h

        layers.append(nn.Linear(prev, num_classes))

        self.net = nn.Sequential(*layers)

    def forward(self, seq_emb, taxon_id):
        """Return class logits.

        Args:
            seq_emb: Protein embeddings; concatenated with the taxon features
                on dim 1, so a (batch, input_dim) layout is expected.
            taxon_id: Integer taxon indices, one per row of ``seq_emb``.
        """
        seq = self.seq_norm(seq_emb)
        tax = self.taxon_norm(self.taxon_embedding(taxon_id))
        x = torch.cat([seq, tax], dim=1)
        return self.net(x)

def gpu_calculate_fmax(preds_t, targets_t, ia_weights_np, device):
    """Return the best weighted F1 over 51 evenly spaced thresholds in [0, 1].

    Rows whose weighted target sum is zero are dropped. For each threshold,
    weighted precision and recall are computed per remaining row (precision is
    0 for rows with no prediction above the threshold), averaged over those
    rows, and combined into an F1 score.

    Args:
        preds_t: Prediction scores, indexed (row, class).
        targets_t: Targets with the same shape as ``preds_t``.
        ia_weights_np: NumPy array of per-class weights.
        device: Device the weights and thresholds are created on.

    Returns:
        The maximum F1 as a Python float; 0.0 when there is nothing to score.
    """
    if preds_t.numel() == 0:
        return 0.0

    weights = torch.from_numpy(ia_weights_np.astype(np.float32)).to(device).unsqueeze(0)
    weighted_truth = (targets_t * weights).sum(dim=1)
    keep = weighted_truth > 0
    if keep.sum().item() == 0:
        return 0.0

    scores = preds_t[keep]
    truth = targets_t[keep]
    truth_mass = weighted_truth[keep]

    best_f1 = 0.0
    for tau in torch.linspace(0.0, 1.0, 51, device=device):
        hits = (scores >= tau).float()
        tp = ((hits * truth) * weights).sum(dim=1)
        pred_mass = (hits * weights).sum(dim=1)

        zeros = torch.zeros_like(tp)
        precision = torch.where(pred_mass != 0, tp / pred_mass, zeros).mean()
        recall = torch.where(truth_mass != 0, tp / truth_mass, zeros).mean()

        total = precision + recall
        f1 = (2.0 * precision * recall / total).item() if total > 0 else 0.0
        best_f1 = max(best_f1, f1)

    return float(best_f1)

def gpu_fmax_split(preds_t, targets_t, ia_t):
    """Best weighted F1 over 101 evenly spaced thresholds in [0, 1].

    preds_t: Tensor [N, C] of prediction scores
    targets_t: Tensor [N, C] of targets
    ia_t: Tensor [C] of per-class weights

    Rows with zero weighted target mass are ignored. Per threshold, weighted
    precision/recall are averaged over the remaining rows and turned into F1;
    the largest F1 is returned as a float (0.0 if no row qualifies).
    """
    device = preds_t.device
    class_w = ia_t.unsqueeze(0)

    truth_mass_all = (targets_t * class_w).sum(1)
    keep = truth_mass_all > 0
    if not keep.any():
        return 0.0

    scores, truth, truth_mass = preds_t[keep], targets_t[keep], truth_mass_all[keep]

    zero = torch.tensor(0., device=device)
    best = 0.0
    for tau in torch.linspace(0, 1, 101, device=device):
        hits = (scores >= tau).float()

        tp = (hits * truth * class_w).sum(1)
        pred_mass = (hits * class_w).sum(1)

        blank = torch.zeros_like(tp)
        mean_prec = torch.where(pred_mass > 0, tp / pred_mass, blank).mean()
        mean_rec = torch.where(truth_mass > 0, tp / truth_mass, blank).mean()
        both = mean_prec + mean_rec

        # The where() picks 0 whenever precision + recall is not positive.
        score = torch.where(both > 0, 2 * mean_prec * mean_rec / both, zero).item()
        if score > best:
            best = score

    return best

# =========================================================
# 3. TRAINING LOOP
# =========================================================
def train_epoch(model, ema_model, loader, loss_fn, optimizer, scaler, scheduler, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Each batch goes through a bfloat16-autocast forward pass, a scaled
    backward pass, gradient-norm clipping at 5.0, an optimizer step, a
    scheduler step and (when ``ema_model`` is given) an EMA update.

    Args:
        model: Network called as ``model(feats, tax_id)``.
        ema_model: Object with ``update(model)``, or ``None`` to skip EMA.
        loader: Iterable of ``(feats, tax_id, labels)`` batches.
        loss_fn: Callable ``loss_fn(logits, labels)`` returning a scalar.
        optimizer: Optimizer over ``model``'s parameters.
        scaler: Gradient scaler wrapping backward/step.
        scheduler: LR scheduler stepped once per batch.
        device: Device the batch tensors are moved to.

    Returns:
        Sum of the per-batch losses divided by ``len(loader)``.

    Raises:
        RuntimeError: If a batch produces a non-finite loss. The current
            weights are saved to ``crash_before_nan.pth`` first.
    """
    model.train()
    running_loss = 0.0
    pbar = tqdm(loader, desc="Training")

    for feats, tax_id, labels in pbar:
        # non_blocking lets host-to-device copies overlap with compute when
        # the loader delivers pinned memory; it is a no-op otherwise.
        feats = feats.to(device, non_blocking=True)
        tax_id = tax_id.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()

        with autocast(device_type="cuda", dtype=torch.bfloat16):
            logits = model(feats, tax_id)
            loss = loss_fn(logits, labels)

        # Fail fast before a bad loss reaches the weights or the EMA copy.
        if not torch.isfinite(loss):
            print("!!! LOSS NaN or Inf detected.")
            print("max_logit", logits.max().item(), "min_logit", logits.min().item())
            torch.save(model.state_dict(), "crash_before_nan.pth")
            raise RuntimeError("NaN in loss")

        scaler.scale(loss).backward()
        # Unscale first so the clipping threshold applies to the true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)

        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        # Update EMA
        if ema_model is not None:
            ema_model.update(model)

        running_loss += float(loss.item())
        # No per-batch torch.cuda.empty_cache()/ipc_collect(): the caching
        # allocator already reuses freed blocks, and flushing it every step
        # only forces fresh device allocations on the next batch.

    return running_loss / len(loader)

@torch.no_grad()
def validate_split_go_gpu(model, loader, loss_fn, device, ia_weights, vocab_df):
    """Evaluate ``model`` on ``loader`` and score each GO aspect separately.

    All sigmoid outputs and labels are collected on ``device``; the columns
    belonging to each of the aspects "MFO", "BPO" and "CCO" (per the
    ``aspect`` column of ``vocab_df``) are then scored with
    ``gpu_fmax_split``.

    Args:
        model: Network called as ``model(feats, tax_id)``; switched to eval mode.
        loader: Iterable of ``(feats, tax_id, labels)`` batches.
        loss_fn: Callable ``loss_fn(logits, labels)`` returning a scalar.
        device: Device used for inference and scoring.
        ia_weights: Per-class weights (array-like, one entry per class).
        vocab_df: DataFrame with an ``aspect`` column; row position is assumed
            to equal the class/column index -- TODO confirm against the
            vocabulary file.

    Returns:
        ``(val_loss, avg_fmax, scores)``: mean batch loss, mean of the
        per-aspect scores (0.0 if no aspect is present) and a dict mapping
        each present aspect to its score.
    """
    model.eval()

    preds_list = []
    targets_list = []
    total_loss = 0.0

    for feats, tax_id, labels in loader:
        feats = feats.to(device)
        tax_id = tax_id.to(device)
        labels = labels.to(device)

        logits = model(feats, tax_id)
        loss = loss_fn(logits, labels)
        total_loss += loss.item()

        preds_list.append(torch.sigmoid(logits))
        targets_list.append(labels)

    # Batches already live on ``device``; concatenation keeps them there.
    preds_t = torch.cat(preds_list, dim=0)
    targets_t = torch.cat(targets_list, dim=0)
    val_loss = total_loss / len(loader)

    # Convert IA to torch
    ia_t = torch.tensor(ia_weights, dtype=torch.float32, device=device)

    scores = {}

    for asp in ["MFO", "BPO", "CCO"]:
        # Positional row numbers, so a non-default DataFrame index cannot
        # silently select the wrong prediction columns.
        cols = np.flatnonzero((vocab_df["aspect"] == asp).to_numpy()).tolist()
        if not cols:
            continue

        scores[asp] = gpu_fmax_split(
            preds_t[:, cols],
            targets_t[:, cols],
            ia_t[cols]
        )

    # np.mean of an empty list is NaN; report 0.0 when nothing was scored.
    avg_fmax = float(np.mean(list(scores.values()))) if scores else 0.0

    return val_loss, avg_fmax, scores

# =========================================================
# MAIN
# =========================================================
def main():
    """Train on the training split with per-epoch validation of the EMA weights.

    Steps: seed the RNGs, load IA weights and both dataset splits, build the
    model / EMA / loss / AdamW / OneCycleLR, then train for up to
    ``CONFIG['epochs']`` epochs. After each epoch the EMA model is validated;
    the mean of the most recent (up to three) F-max values is the selection
    metric. The EMA state_dict is written to ``CONFIG['save_path']`` whenever
    that metric improves, and training stops after ``CONFIG['patience']``
    consecutive epochs without improvement.
    """
    # 1. Init Seed
    seed_everything(CONFIG['seed'])

    # 2. Load Data & IA
    ia_weights = load_ia_weights(CONFIG['VOCAB_FILE'], CONFIG['IA_FILE'])
    train_ds = ProteinDataset(CONFIG['TRAIN_IDS'], CONFIG['TARGETS_PKL'], CONFIG['EMBED_DIR'], CONFIG['TAXON_PKL'])
    val_ds = ProteinDataset(CONFIG['VAL_IDS'], CONFIG['TARGETS_PKL'], CONFIG['EMBED_DIR'], CONFIG['TAXON_PKL'])

    # Pad (with 1.0) or truncate the IA vector to the dataset's class count.
    if len(ia_weights) != train_ds.num_classes:
        print(f"Resizing IA: {len(ia_weights)} -> {train_ds.num_classes}")
        new_ia = np.ones(train_ds.num_classes, dtype=np.float32)
        min_len = min(len(ia_weights), train_ds.num_classes)
        new_ia[:min_len] = ia_weights[:min_len]
        ia_weights = new_ia

    CONFIG['output_classes'] = train_ds.num_classes

    train_loader = DataLoader(train_ds, batch_size=CONFIG['batch_size'], shuffle=True,
                              num_workers=16, pin_memory=True, persistent_workers=False, prefetch_factor=2)
    val_loader = DataLoader(val_ds, batch_size=CONFIG['batch_size'], shuffle=False,
                            num_workers=16, pin_memory=True, persistent_workers=False, prefetch_factor=2)

    # 3. Init Model & EMA
    # NOTE(review): the banner text says "C95 MFO" while every configured path
    # points at c99 data -- confirm which label is intended.
    print(f"START OFFICIAL TRAINING V2 (C95 MFO | Seed: {CONFIG['seed']})")
    model = WideProteinMLP_WithTaxon(
        input_dim=CONFIG['input_dim'],
        num_classes=CONFIG['output_classes'],
        num_taxa=train_ds.num_taxa,
        taxon_dim=64,
        hidden_dims=[4096, 4096],
        dropout=0.25
    ).to(CONFIG['device'])

    # Initialize EMA Model
    ema_model = ModelEMA(model, decay=CONFIG['ema_decay'])
    print(f"   Using EMA Decay: {CONFIG['ema_decay']}")

    loss_fn = AsymmetricLossOptimized(gamma_neg=CONFIG['gamma_neg'], gamma_pos=CONFIG['gamma_pos'], clip=CONFIG['clip'])
    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG['lr_max'], weight_decay=0.01)
    scaler = GradScaler()

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=CONFIG['lr_max'], epochs=CONFIG['epochs'],
        steps_per_epoch=len(train_loader), pct_start=0.3, div_factor=25, final_div_factor=1e4
    )

    # 4. Training Loop
    best_proxy_fmax = 0.0
    patience_counter = 0

    fmax_buffer = deque(maxlen=3)   # window = 3

    vocab_df = pd.read_csv(CONFIG['VOCAB_FILE'])

    for epoch in range(CONFIG['epochs']):
        tr_loss = train_epoch(model, ema_model, train_loader, loss_fn, optimizer, scaler, scheduler, CONFIG['device'])

        # Validate the EMA weights (not the raw model).
        val_loss, val_fmax, go_scores = validate_split_go_gpu(
            ema_model.module,
            val_loader,
            loss_fn,
            CONFIG["device"],
            ia_weights,
            vocab_df
        )

        print(f"Epoch {epoch+1}/{CONFIG['epochs']} | Train Loss: {tr_loss:.4f} | Val Loss: {val_loss:.4f} | EMA F-max: {val_fmax:.4f}")
        # .get() so a vocabulary lacking one aspect prints None instead of
        # raising KeyError.
        print("MF:", go_scores.get("MFO"), "BP:", go_scores.get("BPO"), "CC:", go_scores.get("CCO"))

        # Update the rolling window.
        fmax_buffer.append(val_fmax)

        # Always score the epoch with the mean of the window, including while
        # it is still filling up. Mixing raw values (first epochs) with
        # 3-epoch means (later epochs) made an improving epoch look like a
        # regression as soon as the window became full.
        proxy_fmax = float(np.mean(fmax_buffer))

        if proxy_fmax > best_proxy_fmax:
            best_proxy_fmax = proxy_fmax
            patience_counter = 0
            torch.save(ema_model.module.state_dict(), CONFIG['save_path'])
            print(f"    Saved Best EMA Model (proxy F-max: {best_proxy_fmax:.4f})")
        else:
            patience_counter += 1
            print(f"    Patience: {patience_counter}/{CONFIG['patience']}")
            if patience_counter >= CONFIG['patience']:
                print(" Early Stopping Triggered!")
                break

        # Once-per-epoch housekeeping between training and the next epoch.
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    print(f"\n Finished. Best EMA proxy F-max: {best_proxy_fmax:.4f}")

if __name__ == "__main__":
    main()

 Seed set to 42
 Loading IA Weights...
    Loaded IA for 15582 terms.
START OFFICIAL TRAINING V2 (C95 MFO | Seed: 42)
   Using EMA Decay: 0.999


Training: 100%|██████████| 344/344 [00:17<00:00, 19.43it/s]


Epoch 1/20 | Train Loss: 67.5758 | Val Loss: 314.3646 | EMA F-max: 0.1247
MF: 0.16382519900798798 BP: 0.09105165302753448 CC: 0.1191985160112381
    Saved Best EMA Model (proxy F-max: 0.1247)


Training: 100%|██████████| 344/344 [00:16<00:00, 20.59it/s]


Epoch 2/20 | Train Loss: 31.9462 | Val Loss: 122.7725 | EMA F-max: 0.3524
MF: 0.4022637903690338 BP: 0.2591497302055359 CC: 0.3959347903728485
    Saved Best EMA Model (proxy F-max: 0.3524)


Training: 100%|██████████| 344/344 [00:16<00:00, 20.76it/s]


Epoch 3/20 | Train Loss: 29.6347 | Val Loss: 51.9845 | EMA F-max: 0.4178
MF: 0.45802152156829834 BP: 0.30861836671829224 CC: 0.4866268038749695
    Patience: 1/6


Training: 100%|██████████| 344/344 [00:16<00:00, 20.64it/s]


Epoch 4/20 | Train Loss: 28.3631 | Val Loss: 34.8537 | EMA F-max: 0.4507
MF: 0.4955832064151764 BP: 0.3308239281177521 CC: 0.5255696177482605
    Saved Best EMA Model (proxy F-max: 0.4070)


Training: 100%|██████████| 344/344 [00:16<00:00, 20.52it/s]


Epoch 5/20 | Train Loss: 27.3542 | Val Loss: 30.6259 | EMA F-max: 0.4744
MF: 0.5242636203765869 BP: 0.3494601845741272 CC: 0.54942387342453
    Saved Best EMA Model (proxy F-max: 0.4476)


Training: 100%|██████████| 344/344 [00:16<00:00, 20.70it/s]


Epoch 6/20 | Train Loss: 26.3936 | Val Loss: 28.9254 | EMA F-max: 0.4927
MF: 0.5488895177841187 BP: 0.3651718199253082 CC: 0.5639460682868958
    Saved Best EMA Model (proxy F-max: 0.4726)


Training: 100%|██████████| 344/344 [00:17<00:00, 20.16it/s]


Epoch 7/20 | Train Loss: 25.3263 | Val Loss: 27.9263 | EMA F-max: 0.5055
MF: 0.5656755566596985 BP: 0.378062903881073 CC: 0.572760283946991
    Saved Best EMA Model (proxy F-max: 0.4909)


Training: 100%|██████████| 344/344 [00:17<00:00, 19.20it/s]


Epoch 8/20 | Train Loss: 24.1579 | Val Loss: 27.3184 | EMA F-max: 0.5161
MF: 0.5797516703605652 BP: 0.3887706398963928 CC: 0.5798185467720032
    Saved Best EMA Model (proxy F-max: 0.5048)


Training: 100%|██████████| 344/344 [00:19<00:00, 17.44it/s]


Epoch 9/20 | Train Loss: 22.9604 | Val Loss: 27.0002 | EMA F-max: 0.5238
MF: 0.5891242623329163 BP: 0.3962215483188629 CC: 0.5859778523445129
    Saved Best EMA Model (proxy F-max: 0.5151)


Training: 100%|██████████| 344/344 [00:17<00:00, 20.18it/s]


Epoch 10/20 | Train Loss: 21.7375 | Val Loss: 26.9346 | EMA F-max: 0.5288
MF: 0.5949710607528687 BP: 0.40131181478500366 CC: 0.5900393724441528
    Saved Best EMA Model (proxy F-max: 0.5229)


Training: 100%|██████████| 344/344 [00:17<00:00, 20.03it/s]


Epoch 11/20 | Train Loss: 20.4698 | Val Loss: 27.1004 | EMA F-max: 0.5313
MF: 0.5973190665245056 BP: 0.4045393466949463 CC: 0.5919551253318787
    Saved Best EMA Model (proxy F-max: 0.5279)


Training: 100%|██████████| 344/344 [00:17<00:00, 20.02it/s]


Epoch 12/20 | Train Loss: 19.1877 | Val Loss: 27.4687 | EMA F-max: 0.5328
MF: 0.598224401473999 BP: 0.4061540365219116 CC: 0.5939397215843201
    Saved Best EMA Model (proxy F-max: 0.5309)


Training: 100%|██████████| 344/344 [00:16<00:00, 20.24it/s]


Epoch 13/20 | Train Loss: 17.9695 | Val Loss: 28.0098 | EMA F-max: 0.5321
MF: 0.5978807806968689 BP: 0.4047158360481262 CC: 0.5936333537101746
    Saved Best EMA Model (proxy F-max: 0.5320)


Training: 100%|██████████| 344/344 [00:17<00:00, 19.56it/s]


Epoch 14/20 | Train Loss: 16.7903 | Val Loss: 28.7036 | EMA F-max: 0.5303
MF: 0.5948261618614197 BP: 0.40303143858909607 CC: 0.59291672706604
    Patience: 1/6


Training: 100%|██████████| 344/344 [00:17<00:00, 19.45it/s]


Epoch 15/20 | Train Loss: 15.7645 | Val Loss: 29.5324 | EMA F-max: 0.5281
MF: 0.5918530225753784 BP: 0.4006190896034241 CC: 0.5917574763298035
    Patience: 2/6


Training: 100%|██████████| 344/344 [00:17<00:00, 19.41it/s]


Epoch 16/20 | Train Loss: 14.8815 | Val Loss: 30.4301 | EMA F-max: 0.5257
MF: 0.5886651873588562 BP: 0.3979606330394745 CC: 0.590548574924469
    Patience: 3/6


Training: 100%|██████████| 344/344 [00:18<00:00, 18.84it/s]


Epoch 17/20 | Train Loss: 14.2113 | Val Loss: 31.3588 | EMA F-max: 0.5240
MF: 0.5860811471939087 BP: 0.3961750566959381 CC: 0.5897903442382812
    Patience: 4/6


Training: 100%|██████████| 344/344 [00:16<00:00, 20.49it/s]


Epoch 18/20 | Train Loss: 13.6950 | Val Loss: 32.2558 | EMA F-max: 0.5220
MF: 0.5838826894760132 BP: 0.39440321922302246 CC: 0.587813675403595
    Patience: 5/6


Training: 100%|██████████| 344/344 [00:16<00:00, 20.52it/s]


Epoch 19/20 | Train Loss: 13.4050 | Val Loss: 33.0437 | EMA F-max: 0.5205
MF: 0.5818638801574707 BP: 0.39279839396476746 CC: 0.5867449045181274
    Patience: 6/6
 Early Stopping Triggered!

 Finished. Best EMA proxy F-max: 0.5320


In [None]:


# =========================================================
# CONFIG (FULL-TRAIN MODE)
# =========================================================
# Settings for the final run that trains on the train and validation IDs
# combined (main() concatenates TRAIN_IDS and VAL_IDS).
CONFIG = {
    # --- Paths ---
    'TRAIN_IDS': '/root/CAFA6data/c99/train_ids_C99_split.npy',  # .npy array of training protein IDs
    'VAL_IDS': '/root/CAFA6data/c99/val_ids_C99_split.npy',  # .npy array of validation protein IDs
    'TARGETS_PKL': '/root/CAFA6data/c99/train_targets_C99.pkl',  # pickled dict: protein ID -> label indices
    'EMBED_DIR': '/root/CAFA6data/cafa6-embeds',  # directory holding train_embeds.npy and train_ids.txt
    'IA_FILE': '/root/CAFA6data/IA.tsv',  # tab-separated "term<TAB>weight" lines
    'VOCAB_FILE': '/root/CAFA6data/c99/vocab_C99_remove.csv',  # CSV with a 'term' column
    'TAXON_PKL': '/root/cafa6/preprocessing/taxon_mapping_K_Species.pkl',  # pickle with 'prot_to_taxon_idx' and 'num_taxa_classes'

    # --- Training Hypers ---
    'input_dim': 1280,  # passed to the model as input_dim
    'batch_size': 192,  # DataLoader batch size
    'device': "cuda",  # device the model and batches are moved to
    'lr_max': 4e-4,  # AdamW lr and OneCycleLR max_lr
    'seed': 42,  # forwarded to seed_everything
    'ema_decay': 0.999,  # decay handed to ModelEMA

    'epochs': 20,  # OneCycleLR schedule length and loop upper bound
    'stop_epoch': 13,  # training breaks once epoch + 1 >= this value

    # --- Best Params found ---
    'gamma_neg': 4,  # AsymmetricLossOptimized gamma_neg
    'gamma_pos': 1.0,  # AsymmetricLossOptimized gamma_pos
    'clip': 0.03,  # AsymmetricLossOptimized clip

}

# =========================================================
# 1. UTILS
# =========================================================
def seed_everything(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Integer seed applied to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Same seed for every random source the training code touches.
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

    # Trade cuDNN autotuning for repeatable kernel selection.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    print(f" Seed set to {seed}")

class ModelEMA:
    """Keep an exponential moving average (EMA) copy of a model.

    ``self.module`` is a frozen, eval-mode clone of the model passed to the
    constructor. Each call to :meth:`update` moves its floating-point state
    towards the live model: ``ema = decay * ema + (1 - decay) * model``.

    Args:
        model: The ``nn.Module`` to shadow.
        decay: EMA decay factor; values closer to 1 average over more steps.
    """

    def __init__(self, model, decay=0.999):
        # deepcopy clones the structure *and* the current weights, so the
        # shadow starts out identical to ``model``.
        self.module = copy.deepcopy(model).eval()
        self.decay = decay
        # The shadow is never optimised directly, so it needs no gradients.
        for p in self.module.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        """Blend the current state of ``model`` into the EMA copy in place.

        Args:
            model: Live model whose ``state_dict`` keys match the shadow's.
        """
        msd = model.state_dict()
        for k, ema_v in self.module.state_dict().items():
            model_v = msd[k].detach()
            if ema_v.dtype.is_floating_point:
                ema_v.copy_(ema_v * self.decay + (1.0 - self.decay) * model_v)
            else:
                # Integer/bool buffers (e.g. BatchNorm's num_batches_tracked)
                # are counters, not weights: averaging them in float and
                # truncating back would corrupt them, so mirror them verbatim.
                ema_v.copy_(model_v)

def load_ia_weights(vocab_path, ia_path):
    """Build a per-class weight vector from the vocabulary and the IA file.

    The vocabulary CSV must have a ``term`` column; row order defines the class
    index. The IA file is read as tab-separated ``term<TAB>weight`` lines.
    Terms missing from the IA file keep a weight of 1.0, and lines whose weight
    is not numeric (e.g. a header row) are skipped instead of aborting the
    whole load.

    Args:
        vocab_path: Path to the vocabulary CSV.
        ia_path: Path to the IA file.

    Returns:
        ``np.float32`` array with one weight per vocabulary row. If the
        vocabulary itself cannot be read, a length-1 array of ones is returned
        (the historical fallback; callers are expected to resize it).
    """
    print("Loading IA Weights...")
    try:
        vocab_df = pd.read_csv(vocab_path)
        term_list = vocab_df['term'].tolist()
    except (OSError, KeyError, ValueError) as e:
        # Report the failure instead of hiding it behind the fallback.
        print(f"    Error loading IA: {e}")
        return np.ones(1, dtype=np.float32)

    term_to_idx = {t: i for i, t in enumerate(term_list)}
    num_classes = len(term_list)
    ia_weights = np.ones(num_classes, dtype=np.float32)

    if not os.path.exists(ia_path):
        print("    IA file not found. Using weights=1.0")
        return ia_weights

    try:
        with open(ia_path, 'r') as f:
            for line in f:
                parts = line.strip().split('\t')
                if len(parts) < 2:
                    continue
                try:
                    ia = float(parts[1])
                except ValueError:
                    # Header or malformed row: ignore it, keep the rest.
                    continue
                idx = term_to_idx.get(parts[0])
                if idx is not None:
                    ia_weights[idx] = ia
    except (OSError, ValueError) as e:
        # Unreadable / undecodable file: fall back to uniform weights.
        print(f"    Error loading IA: {e}")
        return np.ones(num_classes, dtype=np.float32)

    return ia_weights

class ProteinDataset(Dataset):
    """Dataset yielding ``(embedding, taxon_index, multi_hot_target)`` per protein.

    Args:
        ids_source: Either a path to a ``.npy`` file of protein IDs or an
            already-loaded sequence/array of IDs.
        targets_pkl: Pickle of a dict mapping protein ID -> iterable of
            positive class indices.
        embed_dir: Directory containing ``train_embeds.npy`` (one row per
            protein) and ``train_ids.txt`` (one ID per line, same row order).
        taxon_pkl: Pickle of a dict with keys ``prot_to_taxon_idx`` and
            ``num_taxa_classes``.

    NOTE: ``pickle.load`` can execute arbitrary code; only point this class at
    trusted files.
    """

    def __init__(self, ids_source, targets_pkl, embed_dir, taxon_pkl):

        # Paths are loaded from disk; anything else is used as the ID array.
        if isinstance(ids_source, (str, os.PathLike)):
            self.ids = np.load(ids_source)
        else:
            self.ids = ids_source

        # mmap_mode='r' keeps the embedding matrix on disk to save RAM.
        self.embeds = np.load(os.path.join(embed_dir, "train_embeds.npy"), mmap_mode='r')
        # Shape of a single embedding row; used for the missing-ID fallback
        # instead of a hard-coded width.
        self._feat_shape = tuple(self.embeds.shape[1:])

        with open(os.path.join(embed_dir, "train_ids.txt")) as f:
            all_ids = [x.strip() for x in f]
        self.id_to_idx = {pid: i for i, pid in enumerate(all_ids)}

        with open(targets_pkl, 'rb') as f:
            self.targets = pickle.load(f)

        # Number of classes = largest label index seen anywhere + 1.
        max_idx = 0
        for v in self.targets.values():
            # len() instead of truthiness so array-valued entries work too.
            if v is not None and len(v) > 0:
                max_idx = max(max_idx, int(max(v)))
        self.num_classes = max_idx + 1

        with open(taxon_pkl, 'rb') as f:
            tax_data = pickle.load(f)
        self.prot_to_taxon = tax_data["prot_to_taxon_idx"]
        self.num_taxa = tax_data["num_taxa_classes"]
        self.default_tax = self.num_taxa - 1  # last index doubles as the unknown taxon

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, i):
        """Return ``(feat, tax_id, target)`` for the ``i``-th ID.

        ``feat`` is the float32 embedding row (all zeros when the ID has no
        embedding), ``tax_id`` a scalar long tensor (the last taxon index when
        the ID is unmapped) and ``target`` a float32 multi-hot vector of
        length ``num_classes``.
        """
        pid = self.ids[i]
        emb_idx = self.id_to_idx.get(pid)
        if emb_idx is not None:
            # torch.tensor copies the row out of the memory map into RAM.
            feat = torch.tensor(self.embeds[emb_idx], dtype=torch.float32)
        else:
            feat = torch.zeros(self._feat_shape, dtype=torch.float32)

        target = torch.zeros(self.num_classes, dtype=torch.float32)
        if pid in self.targets:
            inds = self.targets[pid]
            if len(inds) > 0:
                # Plain ints so NumPy scalar indices are accepted as well.
                target[[int(x) for x in inds]] = 1.0

        tax_id = self.prot_to_taxon.get(pid, self.default_tax)
        return feat, torch.tensor(tax_id, dtype=torch.long), target

# =========================================================
# 2. MODEL & LOSS
# =========================================================
class AsymmetricLossOptimized(nn.Module):
    """Asymmetric focal-style loss on multi-label logits.

    Positives and negatives get separate focusing exponents, and the
    negative-class probability is shifted by ``clip`` (then capped at 1)
    before the log is taken.

    Args:
        gamma_neg: Focusing exponent applied to negative labels.
        gamma_pos: Focusing exponent applied to positive labels.
        clip: Shift added to the negative probability.
        eps: Small constant guarding the log and the clamp.
    """

    def __init__(self, gamma_neg=4, gamma_pos=0, clip=0.05, eps=1e-7):
        super().__init__()
        self.gamma_neg, self.gamma_pos = gamma_neg, gamma_pos
        self.clip, self.eps = clip, eps

    def forward(self, x, y):
        """Return the loss summed over all elements, divided by ``x.size(0)``."""
        logits = x.float()
        labels = y.float()

        prob_pos = torch.sigmoid(logits)
        prob_neg = 1 - prob_pos
        prob_neg = prob_neg + self.clip
        prob_neg = (prob_neg + self.eps).clamp(max=1.0)

        # Probability assigned to the labelled outcome of each element.
        p_true = labels * prob_pos + (1 - labels) * prob_neg
        log_p = torch.log(p_true.clamp(min=self.eps, max=1.0))

        focal_pos = labels * (1 - prob_pos) ** self.gamma_pos
        focal_neg = (1 - labels) * (1 - prob_neg) ** self.gamma_neg

        per_element = -(focal_pos + focal_neg) * log_p
        return per_element.sum() / logits.size(0)

class WideProteinMLP_WithTaxon(nn.Module):
    """MLP classifier over a protein embedding concatenated with a taxon embedding.

    Both inputs are layer-normalised, concatenated along the feature axis and
    passed through ``Linear -> GELU -> Dropout`` stacks followed by a final
    linear layer that emits one logit per class.

    Args:
        input_dim: Width of the protein embedding.
        num_classes: Number of output logits.
        num_taxa: Size of the taxon vocabulary; the row of the last index is
            zero-initialised.
        taxon_dim: Width of the learned taxon embedding.
        hidden_dims: Sizes of the hidden layers (any iterable of ints).
        dropout: Dropout probability after every hidden activation.
    """

    # The default for hidden_dims is an immutable tuple so it cannot be shared
    # and mutated across instances; lists passed by callers still work.
    def __init__(self, input_dim, num_classes, num_taxa, taxon_dim=64, hidden_dims=(4096, 4096), dropout=0.25):
        super().__init__()
        self.seq_norm = nn.LayerNorm(input_dim)
        self.taxon_embedding = nn.Embedding(num_taxa, taxon_dim)
        self.taxon_norm = nn.LayerNorm(taxon_dim)
        # Start the last taxon row at zero. It is not a padding_idx, so it
        # still receives gradient updates during training.
        with torch.no_grad():
            self.taxon_embedding.weight[num_taxa - 1].zero_()

        layers = []
        prev = input_dim + taxon_dim
        for h in hidden_dims:
            layers.append(nn.Linear(prev, h))
            layers.append(nn.GELU())
            layers.append(nn.Dropout(dropout))
            prev = h
        layers.append(nn.Linear(prev, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, seq, tax):
        """Return class logits.

        Args:
            seq: Protein embeddings; concatenated with the taxon features on
                dim 1, so a (batch, input_dim) layout is expected.
            tax: Integer taxon indices, one per row of ``seq``.
        """
        x = torch.cat([self.seq_norm(seq), self.taxon_norm(self.taxon_embedding(tax))], dim=1)
        return self.net(x)

# =========================================================
# METRICS & TRAINING
# =========================================================
def train_epoch(model, ema_model, loader, loss_fn, optimizer, scaler, scheduler, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Each batch goes through a bfloat16-autocast forward pass, a scaled
    backward pass, gradient-norm clipping at 5.0, an optimizer step, a
    scheduler step and (when ``ema_model`` is given) an EMA update.

    Args:
        model: Network called as ``model(feats, tax_id)``.
        ema_model: Object with ``update(model)``, or ``None`` to skip EMA.
        loader: Iterable of ``(feats, tax_id, labels)`` batches.
        loss_fn: Callable ``loss_fn(logits, labels)`` returning a scalar.
        optimizer: Optimizer over ``model``'s parameters.
        scaler: Gradient scaler wrapping backward/step.
        scheduler: LR scheduler stepped once per batch.
        device: Device the batch tensors are moved to.

    Returns:
        Sum of the per-batch losses divided by ``len(loader)``.

    Raises:
        RuntimeError: If a batch produces a non-finite loss.
    """
    model.train()
    running_loss = 0.0
    pbar = tqdm(loader, desc="Training Full")

    for feats, tax_id, labels in pbar:
        # non_blocking lets host-to-device copies overlap with compute when
        # the loader delivers pinned memory; it is a no-op otherwise.
        feats = feats.to(device, non_blocking=True)
        tax_id = tax_id.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        with autocast(device_type="cuda", dtype=torch.bfloat16):
            logits = model(feats, tax_id)
            loss = loss_fn(logits, labels)

        # This loop has no validation pass to reveal a diverged run, so stop
        # before a NaN/Inf loss is propagated into the weights and the EMA.
        if not torch.isfinite(loss):
            raise RuntimeError("NaN in loss")

        scaler.scale(loss).backward()
        # Unscale first so the clipping threshold applies to the true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        # Explicit None check rather than relying on object truthiness.
        if ema_model is not None:
            ema_model.update(model)

        running_loss += float(loss.item())

    return running_loss / len(loader)

# =========================================================
# MAIN
# =========================================================
def main():
    """Train the final model on the train and validation IDs combined.

    The OneCycleLR schedule is laid out for ``CONFIG['epochs']`` epochs, but
    the loop stops once ``CONFIG['stop_epoch']`` epochs have run. The EMA
    weights are then written to ``CONFIG['save_path']`` if that key exists,
    otherwise to ``final_cafa6_model_c99.pth``.
    """
    seed_everything(CONFIG['seed'])

    print(" Merging Train + Val data for FULL TRAINING...")
    ids_train = np.load(CONFIG['TRAIN_IDS'])
    ids_val = np.load(CONFIG['VAL_IDS'])
    ids_full = np.concatenate([ids_train, ids_val])

    full_ds = ProteinDataset(ids_full, CONFIG['TARGETS_PKL'], CONFIG['EMBED_DIR'], CONFIG['TAXON_PKL'])
    full_loader = DataLoader(full_ds, batch_size=CONFIG['batch_size'], shuffle=True,
                             num_workers=16, pin_memory=True, persistent_workers=False)

    # No IA weights are loaded here: this run computes no metric, so the
    # previously loaded-and-resized vector was never read.

    model = WideProteinMLP_WithTaxon(
        input_dim=CONFIG['input_dim'],
        num_classes=full_ds.num_classes,
        num_taxa=full_ds.num_taxa,
        taxon_dim=64,
        hidden_dims=[4096, 4096],
        dropout=0.25
    ).to(CONFIG['device'])

    ema_model = ModelEMA(model, decay=CONFIG['ema_decay'])

    loss_fn = AsymmetricLossOptimized(
        gamma_neg=CONFIG['gamma_neg'],
        gamma_pos=CONFIG['gamma_pos'],
        clip=CONFIG['clip']
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG['lr_max'], weight_decay=0.01)
    scaler = GradScaler()

    # The schedule spans the full CONFIG['epochs'] even though training is cut
    # short at stop_epoch, so the LR is still mid-decay when the loop ends.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=CONFIG['lr_max'],
        epochs=CONFIG['epochs'],
        steps_per_epoch=len(full_loader),
        pct_start=0.3,
        div_factor=25,
        final_div_factor=1e4
    )

    print("FINAL TRAINING START (full data, 20 epochs)")

    for epoch in range(CONFIG['epochs']):
        tr_loss = train_epoch(
            model, ema_model, full_loader,
            loss_fn, optimizer, scaler,
            scheduler, CONFIG['device']
        )

        current_lr = scheduler.get_last_lr()[0]
        print(f"[Full] Epoch {epoch+1}/{CONFIG['epochs']} | Loss: {tr_loss:.4f} | LR: {current_lr:.2e}")

        # Forced stop once the configured number of epochs has been trained.
        if epoch + 1 >= CONFIG['stop_epoch']:
            print(f" Forced stop at epoch {epoch+1} (best F-max observed)")
            break

    # Optional override; the default keeps the historical file name.
    save_path = CONFIG.get('save_path', "final_cafa6_model_c99.pth")
    torch.save(ema_model.module.state_dict(), save_path)
    print(f" FINAL MODEL SAVED: {save_path}")

if __name__ == "__main__":
    main()

 Seed set to 42
 Merging Train + Val data for FULL TRAINING...
Loading IA Weights...
FINAL TRAINING START (full data, 20 epochs)


Training Full: 100%|██████████| 430/430 [00:16<00:00, 26.42it/s]


[Full] Epoch 1/20 | Loss: 60.7082 | LR: 4.17e-05


Training Full: 100%|██████████| 430/430 [00:16<00:00, 26.59it/s]


[Full] Epoch 2/20 | Loss: 31.3414 | LR: 1.12e-04


Training Full: 100%|██████████| 430/430 [00:16<00:00, 26.50it/s]


[Full] Epoch 3/20 | Loss: 29.2087 | LR: 2.08e-04


Training Full: 100%|██████████| 430/430 [00:16<00:00, 26.76it/s]


[Full] Epoch 4/20 | Loss: 27.9943 | LR: 3.04e-04


Training Full: 100%|██████████| 430/430 [00:16<00:00, 26.24it/s]


[Full] Epoch 5/20 | Loss: 27.0426 | LR: 3.74e-04


Training Full: 100%|██████████| 430/430 [00:16<00:00, 26.49it/s]


[Full] Epoch 6/20 | Loss: 26.1325 | LR: 4.00e-04


Training Full: 100%|██████████| 430/430 [00:16<00:00, 26.33it/s]


[Full] Epoch 7/20 | Loss: 25.0885 | LR: 3.95e-04


Training Full: 100%|██████████| 430/430 [00:16<00:00, 26.45it/s]


[Full] Epoch 8/20 | Loss: 23.9432 | LR: 3.80e-04


Training Full: 100%|██████████| 430/430 [00:16<00:00, 26.24it/s]


[Full] Epoch 9/20 | Loss: 22.7908 | LR: 3.56e-04


Training Full: 100%|██████████| 430/430 [00:16<00:00, 25.31it/s]


[Full] Epoch 10/20 | Loss: 21.7979 | LR: 3.25e-04


Training Full: 100%|██████████| 430/430 [00:16<00:00, 26.15it/s]


[Full] Epoch 11/20 | Loss: 20.2878 | LR: 2.87e-04


Training Full: 100%|██████████| 430/430 [00:16<00:00, 25.71it/s]


[Full] Epoch 12/20 | Loss: 19.0122 | LR: 2.44e-04


Training Full: 100%|██████████| 430/430 [00:16<00:00, 26.43it/s]


[Full] Epoch 13/20 | Loss: 17.7908 | LR: 2.00e-04
 Forced stop at epoch 13 (best F-max observed)
 FINAL MODEL SAVED: final_cafa6_model_c99.pth
