In [None]:
# ------------------------------------------------------------
# 1. Imports
# ------------------------------------------------------------
import os, gc, math, random, sys, time
from collections import defaultdict, Counter
from typing import List, Dict, Tuple, Set
import numpy as np
import pandas as pd
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.keras import layers, models, callbacks, optimizers, losses, metrics
from collections import defaultdict

print("Import done!!!")

In [None]:
# ------------------------------------------------------------
# 2. Input / output file locations
# ------------------------------------------------------------
# Root of the competition dataset mount.
_CAFA6_DIR = "/kaggle/input/cafa-6-protein-function-prediction"
# Root of the pre-computed ESM2-650M embedding dataset.
_ESM2_DIR = "/kaggle/input/embedding-esm2-650m/biggest embedding"

# Competition TSV / FASTA files
SAMPLE_SUBMISSION_TSV = f"{_CAFA6_DIR}/sample_submission.tsv"
IA_TSV = f"{_CAFA6_DIR}/IA.tsv"
TRAIN_TERMS_TSV = f"{_CAFA6_DIR}/Train/train_terms.tsv"
TRAIN_TAXONOMY_TSV = f"{_CAFA6_DIR}/Train/train_taxonomy.tsv"
TESTSUPERSET_FASTA = f"{_CAFA6_DIR}/Test/testsuperset.fasta"
TRAIN_SEQUENCES_FASTA = f"{_CAFA6_DIR}/Train/train_sequences.fasta"

# Gene Ontology graph in OBO format
GO_BASIC_OBO = f"{_CAFA6_DIR}/Train/go-basic.obo"

# Pre-computed embedding matrices and the protein-ID arrays that go with them
TRAIN_EMBEDS_NPY = f"{_ESM2_DIR}/train_embeddings_650M.npy"
TRAIN_IDS_NPY = f"{_ESM2_DIR}/train_ids.npy"
TEST_EMBEDS_NPY = f"{_ESM2_DIR}/test_embeddings_650M.npy"
TEST_IDS_NPY = f"{_ESM2_DIR}/test_ids.npy"

# Where the submission is written
OUTPUT_TSV = "/kaggle/working/submission.tsv"

print("Files are listed!!!")

In [None]:
# ------------------------------------------------------------
# 3. CONFIG
# ------------------------------------------------------------
# Single place for paths and hyper-parameters. Built with dict(...) keyword
# syntax; key order is the same as the original literal.
CONFIG = dict(
    # Input / output paths (defined in the previous cell)
    TRAIN_TERMS=TRAIN_TERMS_TSV,
    GO_OBO=GO_BASIC_OBO,
    IA_FILE=IA_TSV,
    OUTPUT_SUBMISSION=OUTPUT_TSV,
    TRAIN_EMBEDS=TRAIN_EMBEDS_NPY,
    TRAIN_IDS=TRAIN_IDS_NPY,
    TEST_EMBEDS=TEST_EMBEDS_NPY,
    TEST_IDS=TEST_IDS_NPY,

    # Rows per chunk when scoring the test set
    PREDICT_BATCH_SIZE=4096,

    # Model / training parameters
    RANDOM_SEED=42,
    BATCH_SIZE=256,
    TOP_K_LABELS=2000,            # number of GO terms kept as labels (None keeps all)
    EPOCHS=70,
    LEARNING_RATE=5e-4,
    HIDDEN_UNITS=2048,
    DROPOUT=0.2,

    # Submission parameters
    TOP_K_PER_PROTEIN=150,        # max GO terms written per test protein
    GLOBAL_THRESHOLD_SEARCH=True,
    THRESHOLD_GRID=[step / 100 for step in range(1, 51)],   # 0.01 ... 0.50

    # Propagation along the GO graph
    PROPAGATE_TRAIN_LABELS=True,
    PROPAGATE_PREDICTIONS=True,
)

print("Config loaded...")
# Seed every RNG in use (Python, NumPy, TensorFlow) with the same value.
for _seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
    _seed_fn(CONFIG["RANDOM_SEED"])

In [None]:
# ------------------------------------------------------------
# 4. Utils for TSV/OBO parsing
# ------------------------------------------------------------
def read_train_terms(path):
    """Read the training-annotation TSV into a ``{protein: [GO term, ...]}`` dict.

    The file is parsed as three tab-separated string columns (protein, GO term,
    ontology aspect); the aspect column is not used here.

    Rows whose term column is not a ``GO:`` identifier are dropped. Because the
    file is read with ``header=None``, this keeps a header line or a malformed
    row from showing up as a fake protein with a fake label. Repeated
    (protein, term) pairs are collapsed, keeping first-occurrence order, so
    term frequencies computed later are not double-counted.

    Args:
        path: path to the annotations TSV.

    Returns:
        dict mapping protein ID -> list of GO term IDs.
    """
    df = pd.read_csv(path, sep="\t", header=None, names=["protein", "go", "ont"], dtype=str)
    is_go_term = df["go"].str.startswith("GO:", na=False)
    df = df.loc[is_go_term].drop_duplicates(subset=["protein", "go"])
    mapping = df.groupby("protein")["go"].apply(list).to_dict()
    print(f"[io] Read training annotations for {len(mapping)} proteins from {path}")
    return mapping

def parse_obo(go_obo_path: str) -> Tuple[Dict[str, Set[str]], Dict[str, Set[str]]]:
    """Parse a GO OBO file into direct-parent and direct-child adjacency maps.

    Only ``[Term]`` stanzas contribute edges. Both ``is_a: <parent>`` and
    ``relationship: part_of <parent>`` lines are recorded as child -> parent
    edges. Lines in other stanza types (e.g. ``[Typedef]``, whose ``is_a``
    lines relate relation names rather than GO terms) are ignored.

    Args:
        go_obo_path: path to the OBO file. If it does not exist, two empty maps
            are returned (downstream code checks ``if parents_map`` to skip
            propagation).

    Returns:
        ``(parents, children)``, two ``defaultdict(set)`` objects mapping
        term -> direct parents and term -> direct children.
    """
    parents = defaultdict(set)
    children = defaultdict(set)
    if not os.path.exists(go_obo_path):
        return parents, children

    cur_id = None      # id of the [Term] stanza currently being read
    in_term = False    # True only while inside a [Term] stanza
    # Be explicit about UTF-8: OBO text may contain non-ASCII characters and the
    # platform default encoding is not guaranteed to be UTF-8.
    with open(go_obo_path, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if line.startswith("[") and line.endswith("]"):
                # Any stanza header ([Term], [Typedef], ...) ends the previous stanza.
                in_term = line == "[Term]"
                cur_id = None
                continue
            if not in_term:
                continue
            if line.startswith("id: "):
                cur_id = line.split("id: ", 1)[1].strip()
                continue
            if not cur_id:
                continue
            if line.startswith("is_a: "):
                parts = line.split()           # ["is_a:", "GO:xxxxxxx", "!", "name", ...]
                if len(parts) >= 2:
                    parent_id = parts[1]
                    parents[cur_id].add(parent_id)
                    children[parent_id].add(cur_id)
            elif line.startswith("relationship: part_of "):
                parts = line.split()           # ["relationship:", "part_of", "GO:xxxxxxx", ...]
                if len(parts) >= 3:
                    parent_id = parts[2]
                    parents[cur_id].add(parent_id)
                    children[parent_id].add(cur_id)
    print(f"[io] Parsed OBO: {len(parents)} nodes with parents")
    return parents, children

def get_ancestors(go_id: str, parents: Dict[str, Set[str]]) -> Set[str]:
    """Return every term reachable from ``go_id`` by following parent edges.

    Performs an iterative depth-first walk over ``parents``. ``go_id`` itself is
    not included unless the graph contains a cycle back to it. Terms with no
    entry in ``parents`` have no ancestors.
    """
    seen: Set[str] = set()
    frontier = [go_id]
    while frontier:
        node = frontier.pop()
        fresh = [parent for parent in parents.get(node, ()) if parent not in seen]
        seen.update(fresh)
        frontier.extend(fresh)
    return seen

def read_IA(path: str) -> pd.DataFrame:
    """Load the information-accretion TSV (no header) as columns ``go`` and ``score``.

    ``score`` is converted to float; ``go`` is left as read.
    """
    ia_table = pd.read_csv(path, sep="\t", header=None, names=["go", "score"])
    return ia_table.astype({"score": float})

In [None]:
# ------------------------------------------------------------
# 5. Load metadata: training annotations and the GO graph
# ------------------------------------------------------------
train_terms = read_train_terms(CONFIG["TRAIN_TERMS"])   # protein -> list of GO terms
go_graph = parse_obo(CONFIG["GO_OBO"])                  # (parents map, children map)
parents_map, children_map = go_graph[0], go_graph[1]

In [None]:
# ------------------------------------------------------------
# 6. Load PRE-COMPUTED TRAIN EMBEDDINGS
# ------------------------------------------------------------
print("[io] Loading pre-computed TRAIN embeddings...")
train_emb_full = np.load(CONFIG["TRAIN_EMBEDS"])
train_ids_full = np.load(CONFIG["TRAIN_IDS"])


def _to_accession(raw_id):
    """Return the accession part of a ``db|ACCESSION|NAME`` style ID.

    IDs without a ``|`` are returned unchanged instead of raising IndexError.
    """
    raw_id = str(raw_id)
    return raw_id.split("|")[1] if "|" in raw_id else raw_id


# Normalise embedding IDs so they match the protein keys of train_terms.
print("ID trước khi sửa:", train_ids_full[0])
train_ids_full = pd.Series(train_ids_full).map(_to_accession).values
print("ID sau khi sửa:", train_ids_full[0])
print(f"[io] Loaded train embeddings shape: {train_emb_full.shape}")
print(f"[io] Loaded train ids shape: {train_ids_full.shape}")

# Keep only embedding rows whose protein has annotations (the embedding file can
# contain more rows than train_terms). A repeated ID is kept once, so the same
# protein cannot end up in both the train and the validation split.
valid_indices = []
valid_ids = []
seen_ids = set()
for idx, pid in enumerate(train_ids_full):
    if pid in train_terms and pid not in seen_ids:
        seen_ids.add(pid)
        valid_indices.append(idx)
        valid_ids.append(pid)

X_emb = train_emb_full[valid_indices]
X_proteins = valid_ids
print(f"[prep] Filtered to {len(X_proteins)} proteins that have labels.")

# Free the full matrices; only the filtered rows are needed from here on.
del train_emb_full, train_ids_full, seen_ids
gc.collect()

In [None]:
# ------------------------------------------------------------
# 7. Label Processing (IA Scores & Propagation)
# ------------------------------------------------------------
# Label propagation: extend each protein's terms with all GO ancestors of those
# terms (is_a / part_of edges from parse_obo), so the label set is closed upward.
if CONFIG["PROPAGATE_TRAIN_LABELS"] and parents_map:
    print("[prep] Propagating train labels up GO graph")
    # Proteins share many terms, so memoise each term's ancestor set. The graph
    # walk then runs once per distinct term instead of once per (protein, term).
    _ancestor_cache: Dict[str, Set[str]] = {}
    propagated = {}
    for p in X_proteins:
        terms = set(train_terms[p])
        extra = set()
        for t in terms:
            anc = _ancestor_cache.get(t)
            if anc is None:
                anc = _ancestor_cache[t] = get_ancestors(t, parents_map)
            extra |= anc           # mutates `extra` only; cached sets stay untouched
        propagated[p] = sorted(terms | extra)
    # Write back only after every protein has been computed.
    for p in X_proteins:
        train_terms[p] = propagated[p]

IA_list = read_IA(CONFIG["IA_FILE"])   # DataFrame with columns "go" and "score"
term_freq = Counter()
for p in X_proteins:
    term_freq.update(train_terms[p])

# Rank terms by IA * training frequency; terms missing from the IA file score 0.
IA_dict = dict(zip(IA_list["go"], IA_list["score"]))
term_combined_score = {t: IA_dict.get(t, 0) * n for t, n in term_freq.items()}

if CONFIG["TOP_K_LABELS"] is not None:
    # sorted() is stable (also with reverse=True): ties keep term_freq order.
    top_terms_sorted = sorted(term_combined_score.items(), key=lambda x: x[1], reverse=True)
    chosen_terms = set([t for t, _ in top_terms_sorted[:CONFIG["TOP_K_LABELS"]]])
    print(f"[prep] Restricting to top-{CONFIG['TOP_K_LABELS']} GO terms")
else:
    chosen_terms = set(term_combined_score.keys())

# Label matrix Y: one row per protein in X_proteins, one column per chosen term
# (columns in sorted term order).
y_labels = [[t for t in train_terms[p] if t in chosen_terms] for p in X_proteins]
mlb = MultiLabelBinarizer(classes=sorted(chosen_terms))
Y = mlb.fit_transform(y_labels).astype(np.float32)
print("[prep] Label matrix shape:", Y.shape)

In [None]:
# ------------------------------------------------------------
# 8. Train / validation split (random, seeded, 10% held out)
# ------------------------------------------------------------
y = Y
X_tr, X_val, y_tr, y_val = train_test_split(
    X_emb,
    y,
    random_state=CONFIG["RANDOM_SEED"],
    test_size=0.1,
)
print("[train] shapes:", X_tr.shape, X_val.shape, y_tr.shape, y_val.shape)

In [None]:
# ------------------------------------------------------------
# 9. Focal loss
# ------------------------------------------------------------
def focal_loss_fixed(gamma, alpha):
    """Return a binary focal-loss function usable as a Keras ``loss``.

    Per element, with ``p`` the clipped prediction and ``y`` the target:
        pos = alpha       * (1 - p)^gamma * y       * log(p)
        neg = (1 - alpha) * p^gamma       * (1 - y) * log(1 - p)
    The returned loss is ``-mean(pos + neg)`` over the last axis. The
    ``(1 - p)^gamma`` / ``p^gamma`` factors shrink the contribution of
    confidently-correct predictions, which reduces the weight of the many easy
    negatives. ``alpha`` balances positives against negatives.

    Args:
        gamma: focusing exponent.
        alpha: weight of the positive term; negatives get ``1 - alpha``.
    """
    def focal_loss_fn(y_true, y_pred):
        # Compute in float32 regardless of input dtypes.
        target = tf.cast(y_true, tf.float32)
        prob = tf.cast(y_pred, tf.float32)

        # Keep probabilities strictly inside (0, 1) so both logs stay finite.
        eps = tf.keras.backend.epsilon()
        prob = tf.clip_by_value(prob, eps, 1. - eps)

        pos_term = alpha * tf.pow(1 - prob, gamma) * target * tf.math.log(prob)
        neg_term = (1 - alpha) * tf.pow(prob, gamma) * (1 - target) * tf.math.log(1 - prob)

        # For 0 <= alpha <= 1 both terms are <= 0 (logs of values in (0, 1)),
        # so negating gives a non-negative loss.
        return -tf.reduce_mean(pos_term + neg_term, axis=-1)

    return focal_loss_fn

In [None]:
# ------------------------------------------------------------
# 10. ResNet MLP
# ------------------------------------------------------------
def build_resnet_model(input_dim, output_dim, hidden_units, dropout, lr, n_blocks,
                       steps_per_epoch=None, epochs=None):
    """Build and compile a residual MLP for multi-label GO-term prediction.

    Layers:
      * projection: Dense(hidden_units, gelu) -> BatchNorm -> Dropout
      * ``n_blocks`` residual blocks: Dense-gelu -> BN -> Dropout -> Dense-gelu -> BN,
        added to the block input, then a gelu activation
      * output: Dense(output_dim, sigmoid), one independent probability per label

    The model is compiled with AdamW (plain Adam if AdamW is unavailable), a
    cosine-decay learning-rate schedule starting at ``lr`` and ending at 1% of it,
    and the focal loss from ``focal_loss_fixed(gamma=2.0, alpha=0.35)``.

    Args:
        input_dim: width of each input row (embedding size).
        output_dim: number of labels; should equal the label matrix width.
        hidden_units: width of every hidden Dense layer.
        dropout: dropout rate after the projection and inside each block.
        lr: initial learning rate of the cosine-decay schedule.
        n_blocks: number of residual blocks.
        steps_per_epoch: optimizer steps per epoch, used to size the decay.
            Defaults to ``ceil(len(X_tr) / CONFIG["BATCH_SIZE"])`` (notebook globals).
        epochs: number of epochs the decay spans. Defaults to ``CONFIG["EPOCHS"]``.

    Returns:
        A compiled Keras ``Model`` named "ResNet_MLP".
    """
    inputs = layers.Input(shape=(input_dim,))

    # Projection from the embedding width to the hidden width.
    x = layers.Dense(hidden_units, activation="gelu")(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(dropout)(x)

    for _ in range(n_blocks):
        res = layers.Dense(hidden_units, activation="gelu")(x)
        res = layers.BatchNormalization()(res)
        res = layers.Dropout(dropout)(res)
        # NOTE(review): an earlier comment described this layer as linear before the
        # residual add, but it has always used gelu; kept as-is to preserve the model.
        res = layers.Dense(hidden_units, activation="gelu")(res)
        res = layers.BatchNormalization()(res)
        x = layers.Add()([x, res])          # identity shortcut
        x = layers.Activation("gelu")(x)    # author's note: 0.273 with relu here

    # Multi-label output: an independent sigmoid per GO term.
    outputs = layers.Dense(output_dim, activation="sigmoid")(x)
    model = models.Model(inputs=inputs, outputs=outputs, name="ResNet_MLP")

    # Size the decay from the real number of optimizer steps. Keras runs
    # ceil(n / batch_size) steps per epoch because the last batch may be partial.
    if steps_per_epoch is None:
        steps_per_epoch = math.ceil(len(X_tr) / CONFIG["BATCH_SIZE"])
    if epochs is None:
        epochs = CONFIG["EPOCHS"]
    # CosineDecay divides by decay_steps, so it must be at least 1.
    decay_steps = max(1, int(epochs) * int(steps_per_epoch))

    lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=lr,    # honour the caller's learning rate
        decay_steps=decay_steps,
        alpha=0.01,                  # final LR = 1% of the initial LR
    )

    try:
        opt = optimizers.AdamW(learning_rate=lr_schedule, weight_decay=1e-2)
    except AttributeError:
        # Presumably an older TF/Keras without optimizers.AdamW: fall back to Adam.
        opt = optimizers.Adam(learning_rate=lr_schedule)

    model.compile(
        optimizer=opt,
        loss=focal_loss_fixed(gamma=2.0, alpha=0.35),
        metrics=[metrics.BinaryAccuracy(name='acc'), metrics.AUC(name='auc')]
    )

    return model

In [None]:
# ------------------------------------------------------------
# 11. Build the model and set up callbacks
# ------------------------------------------------------------
# Dimensions come from the real training arrays.
D = X_tr.shape[1]   # input (embedding) width
M = y_tr.shape[1]   # number of label columns produced by the binarizer

model = build_resnet_model(
    input_dim=D,
    # Use the actual label count. TOP_K_LABELS is only an upper bound (fewer
    # terms may exist) and may be None; either case would break fitting on y_tr.
    output_dim=M,
    hidden_units=CONFIG["HIDDEN_UNITS"],
    dropout=CONFIG["DROPOUT"],
    lr=CONFIG["LEARNING_RATE"],
    n_blocks=2,
)

model.summary()
# Early stopping restores the best weights; the checkpoint also keeps the best model on disk.
ckpt_path = "/kaggle/working/best_resnet_model.keras"
es = callbacks.EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=True, verbose=1)
mc = callbacks.ModelCheckpoint(ckpt_path, monitor="val_loss", save_best_only=True, verbose=1)

In [None]:
# ------------------------------------------------------------
# 12. Train
# ------------------------------------------------------------
rlr = callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.5,
    patience=3,
    verbose=1,
    min_lr=1e-6,
)

# ReduceLROnPlateau works by overwriting the optimizer's learning rate. An
# optimizer built with a LearningRateSchedule (build_resnet_model uses
# CosineDecay) does not support that: Keras 3 raises TypeError the first time
# the callback fires, and older tf.keras versions at best let the schedule
# override it. Only add the callback when the learning rate is a plain value.
# NOTE(review): detection relies on the private `_learning_rate` attribute of Keras optimizers.
_opt_lr = getattr(model.optimizer, "_learning_rate", None)
lr_is_scheduled = isinstance(_opt_lr, tf.keras.optimizers.schedules.LearningRateSchedule)
train_callbacks = [es, mc] if lr_is_scheduled else [es, mc, rlr]
if lr_is_scheduled:
    print("[train] Optimizer uses a learning-rate schedule; ReduceLROnPlateau disabled")

history = model.fit(
    X_tr, y_tr,
    validation_data=(X_val, y_val),
    epochs=CONFIG["EPOCHS"],
    batch_size=CONFIG["BATCH_SIZE"],
    callbacks=train_callbacks,
    verbose=2,
    shuffle=True,
)

In [None]:
# ------------------------------------------------------------
# 13. Evaluate & select global threshold
# ------------------------------------------------------------
def read_IA_safe(path):
    """Read IA weights into a ``{go_id: weight}`` dict, tolerating malformed values.

    Returns ``{}`` if ``path`` does not exist. Each value is parsed with
    ``float``. If that fails, it is retried with ``,`` replaced by ``.``
    (decimal-comma files). If it still fails, the weight becomes 0.0. Empty
    cells are read as NaN and stay NaN. For repeated GO ids the last row wins.
    """
    if not os.path.exists(path):
        return {}
    df = pd.read_csv(path, sep="\t", header=None, names=["go", "ia"], dtype=str)

    def _parse(raw):
        try:
            return float(raw)
        except (TypeError, ValueError):
            pass
        try:
            return float(raw.replace(",", "."))
        except (AttributeError, TypeError, ValueError):
            return 0.0

    # Catch only parse errors; KeyboardInterrupt and other errors now propagate.
    return {go: _parse(raw) for go, raw in zip(df["go"], df["ia"])}

# IA weight per GO term, used to weight the per-term F1 during threshold search.
ia_weights = read_IA_safe(path=CONFIG["IA_FILE"])

def weighted_precision_recall_f1(y_true, y_pred_bin, mlb_obj, ia_map=None):
    """IA-weighted average of per-term F1 (term-centric / macro over columns).

    Precision, recall and F1 are computed per label column over all rows. The
    per-column F1 values are then averaged, weighted by each term's IA. Terms
    missing from the map weigh 1.0 and NaN weights count as 0.0 (otherwise one
    empty IA cell would turn the whole score into NaN). Despite the name, only
    the weighted F1 is returned.
    NOTE(review): this is a term-centric proxy, presumably adequate for picking
    a threshold; it is not a protein-centric score.

    Args:
        y_true: (n_samples, n_terms) 0/1 array-like of true labels.
        y_pred_bin: same-shape 0/1 array-like of predictions.
        mlb_obj: object whose ``classes_`` gives the term for each column.
        ia_map: optional ``{term: weight}``; defaults to the notebook-level ``ia_weights``.

    Returns:
        float, the weighted F1.
    """
    weight_map = ia_weights if ia_map is None else ia_map
    y_true = np.asarray(y_true)
    y_pred_bin = np.asarray(y_pred_bin)
    tp = ((y_true == 1) & (y_pred_bin == 1)).sum(axis=0).astype(float)
    fp = ((y_true == 0) & (y_pred_bin == 1)).sum(axis=0).astype(float)
    fn = ((y_true == 1) & (y_pred_bin == 0)).sum(axis=0).astype(float)
    eps = 1e-12  # avoids 0/0 for terms never predicted or never positive
    prec = tp / (tp + fp + eps)
    rec = tp / (tp + fn + eps)
    f1 = 2 * prec * rec / (prec + rec + eps)
    weights = np.array([weight_map.get(c, 1.0) for c in mlb_obj.classes_], dtype=float)
    weights[np.isnan(weights)] = 0.0
    return (f1 * weights).sum() / (weights.sum() + eps)

# Score the validation set once, then sweep the global threshold grid.
y_val_prob = model.predict(X_val, batch_size=CONFIG["BATCH_SIZE"], verbose=0)
best_thresh, best_score = 0.5, -1.0
if CONFIG["GLOBAL_THRESHOLD_SEARCH"]:
    for thr in CONFIG["THRESHOLD_GRID"]:
        thr_score = weighted_precision_recall_f1(y_val, (y_val_prob >= thr).astype(int), mlb)
        # Strict '>' keeps the smallest threshold among ties.
        if thr_score > best_score:
            best_thresh, best_score = thr, thr_score
# NOTE(review): best_thresh is only reported; the submission writer in this
# notebook does not apply it.
print(f"[eval] best_thresh {best_thresh} best IA-weighted F1 {best_score}")

In [None]:
# ------------------------------------------------------------
# 14. Inference on TEST set using Pre-computed Embeddings
# ------------------------------------------------------------
print("[test] Loading pre-computed TEST embeddings...")
test_emb_full = np.load(CONFIG["TEST_EMBEDS"])
test_ids_full = np.load(CONFIG["TEST_IDS"])


def _test_accession(raw_id):
    """Normalise a test ID the way the train IDs are normalised.

    ``db|ACCESSION|NAME`` becomes ``ACCESSION``. IDs without ``|`` are returned
    unchanged, so plain accessions pass through. ``bytes`` IDs are decoded first
    so they are not written to the submission as "b'...'".
    """
    if isinstance(raw_id, bytes):
        raw_id = raw_id.decode("utf-8")
    raw_id = str(raw_id)
    return raw_id.split("|")[1] if "|" in raw_id else raw_id


test_ids_full = np.array([_test_accession(x) for x in test_ids_full], dtype=object)
N_test = len(test_ids_full)
print(f"[test] Loaded {N_test} test embeddings. Shape: {test_emb_full.shape}")

term_to_idx = {t: i for i, t in enumerate(mlb.classes_)}
# For each modelled term, list every modelled ANCESTOR, not just modelled direct
# parents. With direct parents only, a term whose parent was not among the
# chosen labels could never raise its modelled grandparent. With the transitive
# set, prediction propagation matches the fully propagated training labels
# (and converges in a single sweep).
restricted_parents = {}
for t in mlb.classes_:
    restricted_parents[t] = {a for a in get_ancestors(t, parents_map) if a in term_to_idx}

def propagate_batch(pred_batch: np.ndarray, parents_map_local: Dict[str, Set[str]], classes_list: List[str], iterations=3):
    """Raise parent scores so no parent scores below any of its listed children.

    Columns of ``pred_batch`` (shape ``(rows, terms)``) correspond to the first
    ``terms`` entries of ``classes_list``. For each column and each parent listed
    for it in ``parents_map_local``, rows where the child's score exceeds the
    parent's copy the child score into the parent column. Sweeps repeat up to
    ``iterations`` times and stop early after a sweep that changes nothing.

    ``pred_batch`` is modified in place and also returned. Every listed parent
    must appear in ``classes_list`` (otherwise KeyError).
    """
    _, n_terms = pred_batch.shape
    col_terms = [classes_list[col] for col in range(n_terms)]
    col_of = {}
    for col, term in enumerate(col_terms):
        col_of[term] = col

    for _sweep in range(iterations):
        updated = False
        for col, term in enumerate(col_terms):
            child = pred_batch[:, col]
            for parent_term in parents_map_local.get(term, []):
                parent_col = col_of[parent_term]
                higher = child > pred_batch[:, parent_col]
                if higher.any():
                    pred_batch[higher, parent_col] = child[higher]
                    updated = True
        if not updated:
            break
    return pred_batch

# Stream predictions into the submission file batch by batch. The `with` block
# closes (and flushes) the file even if prediction raises part-way through.
out_fpath = CONFIG["OUTPUT_SUBMISSION"]
predict_bs = CONFIG["PREDICT_BATCH_SIZE"]
top_k = CONFIG["TOP_K_PER_PROTEIN"]
class_names = list(mlb.classes_)   # built once rather than once per batch

with open(out_fpath, "w") as out_f:
    print(f"[test] Predicting in batches of {CONFIG['PREDICT_BATCH_SIZE']}...")

    for i in range(0, N_test, predict_bs):
        X_batch = test_emb_full[i : i + predict_bs]
        ids_batch = test_ids_full[i : i + predict_bs]

        # Score this chunk.
        y_batch_prob = model.predict(X_batch, batch_size=min(128, len(X_batch)), verbose=0)

        # Make scores consistent with the GO hierarchy (ancestor >= descendant).
        if CONFIG["PROPAGATE_PREDICTIONS"] and parents_map:
            y_batch_prob = propagate_batch(y_batch_prob, restricted_parents, class_names, iterations=3)

        # Up to top_k terms per protein, scores > 1e-3 only, highest score first.
        for ridx, pid in enumerate(ids_batch):
            probs = y_batch_prob[ridx]
            idxs = np.argsort(probs)[-top_k:]
            idxs = [int(x) for x in idxs if probs[x] > 1e-3]
            idxs = sorted(idxs, key=lambda x: probs[x], reverse=True)
            for idx in idxs:
                score = float(probs[idx])
                go_id = class_names[idx]
                out_f.write(f"{pid}\t{go_id}\t{score:.3f}\n")

        if (i // predict_bs) % 10 == 0:
            out_f.flush()
            print(f"[stream] processed {min(i + predict_bs, N_test)} / {N_test}")

print(f"[done] Submission written to {CONFIG['OUTPUT_SUBMISSION']}")