In [None]:
# ============================================================
# 1. IMPORT
# ============================================================

import os, gc, time
import numpy as np
import pandas as pd
from collections import defaultdict, Counter
from functools import lru_cache

import scipy.sparse as sp
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import MultiLabelBinarizer

[io] train_terms proteins: 82,405
[io] OBO parents nodes: 40,121
[prep] Propagating train labels (cached ancestors)...
[prep] chosen terms: 2,000
[io] Loading train embeddings...
[prep] Train proteins with labels: 82,404, X_train: (82404, 1280)
[prep] Y_csr shape: (82404, 2000), nnz: 2,349,146
[prep] Propagation edges (restricted): 2,065
[train] Fitting KNN index...
[train] KNN fitted in 0.05s
[test] Loading test embeddings (mmap)...
[test] N_test: 224,309, test_emb shape: (224309, 1280)
[stream] 4,096/224,309 done
[stream] 45,056/224,309 done


In [None]:
# ----------------------------
# 2. CONFIG
# ----------------------------
# Root folders of the two Kaggle input datasets; every input path below is
# built from one of them.
_COMPETITION_DIR = "/kaggle/input/cafa-6-protein-function-prediction"
_EMBEDDING_DIR = "/kaggle/input/embedding-esm2-650m/biggest embedding"

CONFIG = dict(
    # --- competition files ---
    SAMPLE_SUBMISSION=f"{_COMPETITION_DIR}/sample_submission.tsv",
    TRAIN_TERMS=f"{_COMPETITION_DIR}/Train/train_terms.tsv",
    GO_OBO=f"{_COMPETITION_DIR}/Train/go-basic.obo",
    IA_FILE=f"{_COMPETITION_DIR}/IA.tsv",

    # --- precomputed embeddings and the id arrays aligned with them ---
    TRAIN_EMBEDS=f"{_EMBEDDING_DIR}/train_embeddings_650M.npy",
    TRAIN_IDS=f"{_EMBEDDING_DIR}/train_ids.npy",
    TEST_EMBEDS=f"{_EMBEDDING_DIR}/test_embeddings_650M.npy",
    TEST_IDS=f"{_EMBEDDING_DIR}/test_ids.npy",

    # --- output ---
    OUTPUT_SUBMISSION="/kaggle/working/submission.tsv",

    # --- KNN ---
    KNN_K=50,             # neighbours retrieved per query protein
    KNN_METRIC="cosine",  # metric handed to NearestNeighbors
    KNN_SIGMA=0.15,       # temperature of the exp(-distance / sigma) vote weights

    # --- label space ---
    TOP_K_LABELS=2000,    # number of GO terms kept by choose_top_terms

    # --- inference ---
    PREDICT_BATCH_SIZE=4096,  # test rows scored per batch
    TOP_K_PER_PROTEIN=150,    # at most this many terms written per protein
    MIN_SCORE=0.001,          # predictions scoring below this are dropped

    # --- GO-graph propagation (optional) ---
    PROPAGATE_TRAIN_LABELS=True,  # add ancestor terms to the training labels
    PROPAGATE_PREDICTIONS=True,   # push predicted scores up to parent terms
    PROP_PASSES=2,                # sweeps passed to propagate_dense

    # Not read by any cell visible in this notebook.
    RANDOM_SEED=42,
)

In [None]:
# ----------------------------
# 3. HELPER FUNCTIONS
# ----------------------------
def clean_ids(arr: np.ndarray) -> np.ndarray:
    """Normalise raw identifiers to plain strings.

    Bytes entries are decoded as UTF-8 (undecodable bytes are dropped); every
    other entry goes through ``str``. If the resulting text contains ``|``,
    only the second pipe-separated field is kept (``'tax|ID'`` -> ``'ID'``).

    Returns an object-dtype array with one cleaned string per input entry.
    """
    def _to_id(raw) -> str:
        if isinstance(raw, (bytes, np.bytes_)):
            text = raw.decode("utf-8", errors="ignore")
        else:
            text = str(raw)
        # A '|' in the text guarantees at least two fields after splitting.
        return text.split("|")[1] if "|" in text else text

    return np.array([_to_id(raw) for raw in np.asarray(arr)], dtype=object)

def l2_normalize_inplace(X: np.ndarray, chunk: int = 8192) -> None:
    """Rescale every row of ``X`` to unit L2 norm, overwriting ``X``.

    Rows are processed ``chunk`` at a time so the temporary norm array stays
    small. Norms are floored at 1e-12, so all-zero rows stay zero instead of
    producing a division by zero. Nothing is returned.
    """
    total_rows = X.shape[0]
    for lo in range(0, total_rows, chunk):
        block = X[lo:lo + chunk]  # a view: writing to it writes to X
        denom = np.maximum(np.linalg.norm(block, axis=1, keepdims=True), 1e-12)
        np.divide(block, denom, out=block)

def l2_normalize_copy(X: np.ndarray) -> np.ndarray:
    """Return a float32 array whose rows are the rows of ``X`` scaled to unit L2 norm.

    The input is left untouched. Norms are floored at 1e-12 so all-zero rows
    come back as zeros.
    """
    rows = np.asarray(X, dtype=np.float32)
    row_norms = np.linalg.norm(rows, axis=1, keepdims=True)
    safe_norms = np.maximum(row_norms, 1e-12)
    return rows / safe_norms

def read_train_terms_fast(path: str) -> dict:
    """Read the training annotations TSV into ``{protein_id: [go_term, ...]}``.

    The file is parsed as three string columns (protein, GO term, ontology).
    If the first row is a column header rather than data (its term field does
    not start with ``"GO:"``), it is discarded; otherwise the header text
    would be counted as one extra "protein" annotated with a fake term.

    Parameters
    ----------
    path : str
        Location of the tab-separated annotations file.

    Returns
    -------
    dict
        Maps each protein id to its GO terms in file order.
    """
    df = pd.read_csv(path, sep="\t", header=None, names=["protein", "go", "ont"], dtype=str)

    # Works for files with or without a header line.
    if len(df) > 0 and not str(df.iloc[0]["go"]).startswith("GO:"):
        df = df.iloc[1:]

    mapping = df.groupby("protein")["go"].apply(list).to_dict()
    print(f"[io] train_terms proteins: {len(mapping):,}")
    return mapping

def read_ia(path: str) -> dict:
    """Load a headerless two-column TSV (GO term, score) as ``{term: score}``.

    Scores are cast to float. If a term appears more than once, the last row
    wins.
    """
    table = pd.read_csv(path, sep="\t", header=None, names=["go", "score"])
    term_values = table["go"].values
    score_values = table["score"].astype(float).values
    return {term: score for term, score in zip(term_values, score_values)}

def parse_obo_parents(go_obo_path: str) -> dict:
    """Extract direct parent links of every term from an OBO file.

    Only ``[Term]`` stanzas are read. Within one, the target of each ``is_a:``
    line and of each ``relationship: part_of`` line is recorded as a parent of
    the stanza's ``id:``. Lines in any other stanza (for example ``[Typedef]``,
    which also carries ``id:`` and ``is_a:`` lines) and in the file header are
    ignored, so relationship-type names never end up in the term graph.

    Parameters
    ----------
    go_obo_path : str
        Path to the ontology file. If it does not exist, a message is printed
        and an empty mapping is returned.

    Returns
    -------
    dict
        ``defaultdict(set)`` mapping child term id -> set of parent term ids.
        Terms without parents have no entry.
    """
    parents = defaultdict(set)
    if not os.path.exists(go_obo_path):
        print(f"[io] OBO file not found, no parent edges loaded: {go_obo_path}")
        return parents

    cur_id = None
    in_term = False
    # Explicit encoding so parsing does not depend on the platform default.
    with open(go_obo_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()

            # Any "[...]" line opens a new stanza; only [Term] ones are parsed.
            if line.startswith("[") and line.endswith("]"):
                in_term = (line == "[Term]")
                cur_id = None
                continue
            if not in_term:
                continue

            if line.startswith("id: "):
                cur_id = line[len("id: "):].strip()
            elif not cur_id:
                # Relationship lines before the stanza's id cannot be attributed.
                continue
            elif line.startswith("is_a: "):
                # "is_a: GO:xxxxxxx ! name" -> second token is the parent id
                fields = line.split()
                if len(fields) >= 2:
                    parents[cur_id].add(fields[1])
            elif line.startswith("relationship: part_of "):
                # "relationship: part_of GO:xxxxxxx ! name" -> third token
                fields = line.split()
                if len(fields) >= 3:
                    parents[cur_id].add(fields[2])

    print(f"[io] OBO parents nodes: {len(parents):,}")
    return parents

def build_edges(parents_map: dict, classes: np.ndarray) -> list:
    """Translate term-level parent links into column-index pairs.

    ``classes`` gives the column order of the label matrix. Every
    ``child -> parent`` link whose two endpoints both appear in ``classes``
    becomes a ``(child_column, parent_column)`` tuple; all other links are
    dropped.
    """
    column_of = dict(zip(classes, range(len(classes))))
    return [
        (column_of[child], column_of[parent])
        for child, parent_terms in parents_map.items()
        if child in column_of
        for parent in parent_terms
        if parent in column_of
    ]

def _edges_children_first(edges: list) -> tuple:
    """Order (child, parent) edges so each node's incoming edges precede its outgoing ones.

    Returns ``(ordered_edges, True)`` when the edges form a DAG, or
    ``(list(edges), False)`` (original order) when a cycle prevents such an
    ordering.
    """
    parents_of = {}   # child -> parents, one entry per edge
    waiting_on = {}   # node -> incoming edges not yet emitted
    for child, parent in edges:
        parents_of.setdefault(child, []).append(parent)
        waiting_on[parent] = waiting_on.get(parent, 0) + 1
        waiting_on.setdefault(child, 0)

    ready = [node for node, count in waiting_on.items() if count == 0]
    ordered = []
    while ready:
        child = ready.pop()
        for parent in parents_of.get(child, ()):
            ordered.append((child, parent))
            waiting_on[parent] -= 1
            if waiting_on[parent] == 0:
                ready.append(parent)

    if len(ordered) == len(edges):
        return ordered, True
    return list(edges), False


def propagate_dense(scores: np.ndarray, edges: list, passes: int = 2) -> np.ndarray:
    """Raise every parent score to at least the score of each of its children.

    For each ``(child, parent)`` column pair the parent column becomes
    ``max(parent, child)``. Edges are first sorted children-first, so when they
    form a DAG a single sweep carries a score along a chain of any length,
    regardless of the order the edges were given in. If the edges contain a
    cycle, they are applied in their original order for ``passes`` sweeps.

    Parameters
    ----------
    scores : np.ndarray
        (B, C) score matrix. Modified in place.
    edges : list
        ``(child_column, parent_column)`` pairs.
    passes : int
        Number of sweeps used in the cyclic fallback; ``passes <= 0`` disables
        propagation entirely.

    Returns
    -------
    np.ndarray
        The same ``scores`` array, after propagation.
    """
    if passes <= 0 or len(edges) == 0:
        return scores

    ordered, is_dag = _edges_children_first(edges)
    sweeps = 1 if is_dag else passes  # children-first order converges in one sweep
    for _ in range(sweeps):
        for c, p in ordered:
            # parent = max(parent, child), written straight into the parent column
            np.maximum(scores[:, p], scores[:, c], out=scores[:, p])
    return scores

def choose_top_terms(train_terms: dict, ia_dict: dict, top_k: int) -> list:
    """Pick the ``top_k`` terms with the largest IA-weighted frequency.

    Each term is scored as ``ia_dict[term] * occurrences``, where occurrences
    are counted over every list in ``train_terms`` and a term missing from
    ``ia_dict`` gets weight 0. Terms are returned best first; equal scores keep
    first-seen order.
    """
    occurrences = Counter()
    for term_list in train_terms.values():
        occurrences.update(term_list)

    # IA-weighted frequency, highest first (stable for ties)
    ranked = sorted(
        ((term, float(ia_dict.get(term, 0.0)) * count) for term, count in occurrences.items()),
        key=lambda pair: pair[1],
        reverse=True,
    )

    chosen = [term for term, _ in ranked[:top_k]]
    print(f"[prep] chosen terms: {len(chosen):,}")
    return chosen



In [None]:
# ----------------------------
# 4. Sparse weighted voting
# ----------------------------
def knn_weighted_vote_sparse(Y_csr: sp.csr_matrix,
                            neighbor_idx: np.ndarray,
                            dists: np.ndarray,
                            sigma: float) -> sp.csr_matrix:
    """Score labels for a batch as a distance-weighted average of neighbour labels.

    Neighbour ``k`` of query ``b`` gets weight proportional to
    ``exp(-dists[b, k] / sigma)``, normalised so each query's weights sum to 1.
    The weights are computed relative to the closest neighbour of each query
    (subtracting the row minimum before exponentiating), which leaves the
    normalised weights unchanged mathematically but prevents them from all
    underflowing to zero when distances are large compared with ``sigma``.

    The weighted sum is a single sparse product ``W @ Y_csr``, where ``W`` is
    the (B, N_train) matrix holding each query's K weights at its neighbours'
    row indices.

    Parameters
    ----------
    Y_csr : sparse matrix, shape (N_train, C)
        Training label matrix.
    neighbor_idx : array, shape (B, K)
        Row indices into ``Y_csr`` for each query's neighbours.
    dists : array, shape (B, K)
        Distance of each neighbour, aligned with ``neighbor_idx``.
    sigma : float
        Kernel temperature; must be positive.

    Returns
    -------
    scipy.sparse.csr_matrix, shape (B, C)
        Per-query label scores.

    Raises
    ------
    ValueError
        If ``sigma`` is not positive.
    """
    if not sigma > 0:
        raise ValueError(f"sigma must be positive, got {sigma!r}")

    neighbor_idx = np.asarray(neighbor_idx)
    B, K = neighbor_idx.shape
    n_train, n_classes = Y_csr.shape
    if B == 0 or K == 0:
        return sp.csr_matrix((B, n_classes), dtype=np.float32)

    # After the shift the nearest neighbour has weight exp(0) = 1, so every
    # row sum is >= 1 and the normalisation can never divide by zero.
    d = np.asarray(dists, dtype=np.float64)
    w = np.exp(-(d - d.min(axis=1, keepdims=True)) / sigma)
    w /= w.sum(axis=1, keepdims=True)
    w = w.astype(np.float32)

    # Row b of W stores its K weights at columns neighbor_idx[b]; with exactly
    # K entries per row the CSR row pointer is a plain arithmetic sequence.
    indptr = np.arange(0, B * K + 1, K)
    W = sp.csr_matrix((w.reshape(-1), neighbor_idx.reshape(-1), indptr),
                      shape=(B, n_train))

    scores = (W @ Y_csr).tocsr()  # (B, C)
    return scores




In [None]:
# ============================================================
# 5. LOAD TRAIN TERMS + OBO + IA
# ============================================================
train_terms = read_train_terms_fast(CONFIG["TRAIN_TERMS"])
parents_map = parse_obo_parents(CONFIG["GO_OBO"])
ia_dict = read_ia(CONFIG["IA_FILE"])

# Optional: propagate train labels up GO graph (ancestor closure)
if CONFIG["PROPAGATE_TRAIN_LABELS"] and len(parents_map) > 0:
    print("[prep] Propagating train labels (cached ancestors)...")

    @lru_cache(maxsize=None)
    def ancestors(term: str) -> tuple:
        """Return every term reachable from ``term`` through ``parents_map``.

        Iterative depth-first walk; ``term`` itself is not included unless it
        is reachable from one of its own ancestors. Results are memoised per
        term, since many proteins share the same annotations.
        """
        stack = [term]
        seen = set()
        while stack:
            cur = stack.pop()
            for parent in parents_map.get(cur, ()):
                if parent not in seen:
                    seen.add(parent)
                    stack.append(parent)
        return tuple(seen)

    # Replace each protein's terms by their ancestor closure. Only values of
    # existing keys are reassigned, which is safe while iterating the dict.
    for pid, terms in train_terms.items():
        closure = set(terms)
        for t in terms:
            closure.update(ancestors(t))
        # Sorted rather than list(set): set order of strings changes between
        # interpreter sessions, and the order terms are first seen decides how
        # choose_top_terms breaks ties between equally scored terms.
        train_terms[pid] = sorted(closure)

# Choose label space
chosen_terms = choose_top_terms(train_terms, ia_dict, CONFIG["TOP_K_LABELS"])




In [None]:
# ============================================================
# 6. LOAD TRAIN EMBEDDINGS + FILTER + NORMALIZE
# ============================================================
print("[io] Loading train embeddings...")
train_emb_full = np.load(CONFIG["TRAIN_EMBEDS"])  # usually fits RAM
train_ids_full = clean_ids(np.load(CONFIG["TRAIN_IDS"]))

# Row i of the embedding matrix is paired with train_ids_full[i] below; fail
# loudly if the two arrays do not line up.
if train_emb_full.shape[0] != len(train_ids_full):
    raise ValueError(
        f"train embeddings have {train_emb_full.shape[0]} rows "
        f"but {len(train_ids_full)} ids were loaded"
    )

# Filter to proteins that have labels
valid_mask = np.array([pid in train_terms for pid in train_ids_full], dtype=bool)
valid_indices = np.where(valid_mask)[0]
X_train = np.asarray(train_emb_full[valid_indices], dtype=np.float32)  # copy filtered to RAM
train_pids = train_ids_full[valid_indices].tolist()

# The full arrays are no longer needed once the filtered copies exist.
del train_emb_full, train_ids_full, valid_mask
gc.collect()

print(f"[prep] Train proteins with labels: {len(train_pids):,}, X_train: {X_train.shape}")

# Normalize train embeddings
l2_normalize_inplace(X_train)
gc.collect()

# Build sparse label matrix (CSR), keeping only terms in the chosen label space.
# The lookup set is built once; rebuilding it inside the comprehension would
# repeat that work for every term of every protein.
chosen_set = set(chosen_terms)
y_labels = [[t for t in train_terms[pid] if t in chosen_set] for pid in train_pids]

# Passing classes= fixes the column order to the order of chosen_terms.
mlb = MultiLabelBinarizer(classes=chosen_terms, sparse_output=True)
Y_csr = mlb.fit_transform(y_labels).tocsr().astype(np.float32)
classes = mlb.classes_

print(f"[prep] Y_csr shape: {Y_csr.shape}, nnz: {Y_csr.nnz:,}")

# Build propagation edges restricted to chosen terms
edges = []
if CONFIG["PROPAGATE_PREDICTIONS"] and len(parents_map) > 0:
    edges = build_edges(parents_map, classes)
    print(f"[prep] Propagation edges (restricted): {len(edges):,}")




In [None]:
# ============================================================
# 7. FIT KNN (cosine, brute)
# ============================================================
print("[train] Fitting KNN index...")

# Exact (brute-force) neighbour search using the configured metric and all
# available cores; the original author's note: good for cosine + high-dim
# embeddings.
knn_settings = dict(
    n_neighbors=CONFIG["KNN_K"],
    metric=CONFIG["KNN_METRIC"],
    algorithm="brute",
    n_jobs=-1,
)
knn = NearestNeighbors(**knn_settings)

t0 = time.time()
knn.fit(X_train)
print(f"[train] KNN fitted in {time.time() - t0:.2f}s")




In [None]:
# ============================================================
# 8. INFERENCE ON TEST (memmap) + WRITE SUBMISSION
# ============================================================
print("[test] Loading test embeddings (mmap)...")
# Memory-mapped: only the rows of the current batch are read into RAM.
test_emb = np.load(CONFIG["TEST_EMBEDS"], mmap_mode="r")
test_ids = clean_ids(np.load(CONFIG["TEST_IDS"]))
N_test = len(test_ids)
print(f"[test] N_test: {N_test:,}, test_emb shape: {test_emb.shape}")

# Batches slice test_emb and test_ids with the same offsets, so the two must
# have the same number of rows.
if test_emb.shape[0] != N_test:
    raise ValueError(
        f"test embeddings have {test_emb.shape[0]} rows but {N_test} ids were loaded"
    )

# Detect header from sample submission.
# The first line is read as plain data and only treated as a header when its
# third field is not a number. This way a headerless sample does not have its
# first prediction row copied to the top of the submission.
sample_path = CONFIG["SAMPLE_SUBMISSION"]
write_header = False
header_cols = None
try:
    sample_first_row = pd.read_csv(sample_path, sep="\t", header=None, nrows=1, dtype=str)
except (OSError, ValueError):
    # Missing, unreadable, empty or malformed sample file -> write no header.
    # (pandas' EmptyDataError and ParserError are ValueError subclasses.)
    sample_first_row = None

if sample_first_row is not None and sample_first_row.shape[1] >= 3:
    first_fields = [str(v) for v in sample_first_row.iloc[0, :3]]
    try:
        float(first_fields[2])
    except ValueError:
        # Third field is text, e.g. a column name: this line is a header.
        header_cols = first_fields
        write_header = True

out_path = CONFIG["OUTPUT_SUBMISSION"]
with open(out_path, "w", encoding="utf-8") as f:
    if write_header and header_cols is not None and len(header_cols) >= 3:
        f.write("\t".join(header_cols[:3]) + "\n")

    B = CONFIG["PREDICT_BATCH_SIZE"]
    K = CONFIG["KNN_K"]
    sigma = float(CONFIG["KNN_SIGMA"])
    topk = int(CONFIG["TOP_K_PER_PROTEIN"])
    min_score = float(CONFIG["MIN_SCORE"])

    for start in range(0, N_test, B):
        end = min(start + B, N_test)

        # Materialise this batch from the memmap and normalise it the same way
        # as the train embeddings.
        Xb = np.asarray(test_emb[start:end], dtype=np.float32)
        Xb = l2_normalize_copy(Xb)

        # neighbors in train
        dists, idxs = knn.kneighbors(Xb, n_neighbors=K)

        # sparse voting
        scores_csr = knn_weighted_vote_sparse(Y_csr, idxs, dists, sigma=sigma)

        # convert to dense for fast top-k
        scores = scores_csr.toarray().astype(np.float32, copy=False)

        # optional propagation on dense
        if edges:
            scores = propagate_dense(scores, edges, passes=int(CONFIG["PROP_PASSES"]))

        # Top-k indices per row (vectorized): argpartition selects the k best
        # in arbitrary order, then only those k are sorted descending.
        k_eff = min(topk, scores.shape[1])
        top_idx = np.argpartition(scores, -k_eff, axis=1)[:, -k_eff:]
        top_scores = np.take_along_axis(scores, top_idx, axis=1)
        order = np.argsort(top_scores, axis=1)[:, ::-1]
        top_idx = np.take_along_axis(top_idx, order, axis=1)
        top_scores = np.take_along_axis(top_scores, order, axis=1)

        # Write lines (buffered)
        lines = []
        ids_batch = test_ids[start:end]
        for i, pid in enumerate(ids_batch):
            # filter by min_score
            mask = top_scores[i] >= min_score
            if not np.any(mask):
                # ensure at least 1 prediction
                # NOTE(review): this fallback can write a score of 0.000;
                # confirm the competition accepts zero scores.
                j = int(top_idx[i, 0])
                lines.append(f"{pid}\t{classes[j]}\t{float(top_scores[i,0]):.3f}\n")
                continue

            for j, sc in zip(top_idx[i, mask], top_scores[i, mask]):
                lines.append(f"{pid}\t{classes[int(j)]}\t{float(sc):.3f}\n")

        f.write("".join(lines))

        # Progress message every 10th batch.
        if (start // B) % 10 == 0:
            print(f"[stream] {end:,}/{N_test:,} done")

print(f"[done] Wrote submission: {out_path}")
