In [1]:
import random
import numpy as np
import torch
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

In [2]:

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset
from transformers import AutoTokenizer, AutoModel

from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import TfidfVectorizer
import re



def load_label_file(path: str) -> str:
    """key: value1,value2,... ÌòïÏãùÏúºÎ°ú Îêú .txt ÌååÏùºÏùÑ ÌÜµÏß∏Î°ú ÏùΩÏñ¥ÏÑú Î¨∏ÏûêÏó¥Î°ú Î∞òÌôò"""
    with open(path, "r", encoding="utf-8") as f:
        return f.read()

def parse_key_value_lines(text: str):
    """'key:val1,val2,...' Ïó¨Îü¨ Ï§ÑÏùÑ ÎîïÏÖîÎÑàÎ¶¨Î°ú Î≥ÄÌôò"""
    id2label = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or ":" not in line:
            continue
        key, vals = line.split(":", 1)
        id2label[key.strip()] = vals.strip()
    return id2label

def preprocess_label_text(label_path_str: str):
    cleaned = label_path_str.lower()
    cleaned = re.sub(r"[:,]", " ", cleaned)
    cleaned = re.sub(r"_", " ", cleaned)
    cleaned = re.sub(r"[^a-z0-9 ]", " ", cleaned)
    cleaned = re.sub(r"\s+", " ", cleaned).strip()
    return cleaned

def build_tfidf_vectorizer(label_texts):
    vectorizer = TfidfVectorizer()
    label_tfidf = vectorizer.fit_transform(label_texts)
    return vectorizer, label_tfidf

def compute_lexical_similarity(doc_text, vectorizer, label_tfidf):
    doc_vec = vectorizer.transform([doc_text])
    sims = cosine_similarity(doc_vec, label_tfidf)[0]
    return sims

    
label_raw_text = load_label_file("Amazon_products/class_related_keywords.txt")  # ÎÑ§ ÌååÏùº Ïù¥Î¶ÑÏóê ÎßûÏ∂∞ Î∞îÍøî
id2label = parse_key_value_lines(label_raw_text)

# 2) ÎùºÎ≤® ÌÖçÏä§Ìä∏ Ï†ÑÏ≤òÎ¶¨Ìï¥ÏÑú TF-IDF ÌïôÏäµ
label_keys = list(id2label.keys())
label_texts = [
    preprocess_label_text(f"{k} {id2label[k]}")
    for k in label_keys
]
vectorizer, label_tfidf = build_tfidf_vectorizer(label_texts)

# 3) ÌÖåÏä§Ìä∏Ïö© Î¨∏ÏÑú ÌïòÎÇò ÎÑ£Ïñ¥Î≥¥Í∏∞
doc = "gourmet organic chocolate snack"
doc_clean = preprocess_label_text(doc)
sims = compute_lexical_similarity(doc_clean, vectorizer, label_tfidf)

# 4) Í≤∞Í≥º Î≥¥Í∏∞
label_sims = list(zip(label_keys, sims))
label_sims.sort(key=lambda x: x[1], reverse=True)

for lbl, score in label_sims:
    print(lbl, round(score, 4))


chocolate_bars 0.5059
chocolate_gifts 0.4276
chocolate 0.3754
chocolate_covered_fruit 0.3548
dried_fruit_raisins 0.2765
chocolate_pretzels 0.2373
fresh_baked_cookies 0.2332
grocery_gourmet_food 0.2266
snack_gifts 0.2254
chocolate_assortments 0.2173
candy_chocolate 0.1635
hot_cocoa 0.1609
food 0.1457
gourmet_gifts 0.1342
snack_food 0.121
trail_mix 0.1181
granola_trail_mix_bars 0.1044
fruit_leather 0.1023
toaster_pastries 0.1015
cookies 0.0996
fruit 0.0941
raisins 0.0925
meat_poultry 0.0903
marshmallows 0.087
changing_table_pads_covers 0.0863
popcorn 0.085
granola_bars 0.0846
produce 0.0833
solid_feeding 0.0831
milk 0.0827
chocolate_truffles 0.0822
rice_cakes 0.0777
nutrition_wellness 0.0746
party_mix 0.0734
p_t_s 0.0716
fruit_gifts 0.0684
sensual_delights 0.0649
foie_gras_p_t_s 0.062
sugars_sweeteners 0.0604
salsas 0.0594
eggs 0.059
cakes 0.0569
nutrition_bars_drinks 0.0569
chocolate_covered_nuts 0.0565
dessert_gifts 0.0551
spices_gifts 0.0545
meat_gifts 0.0541
crackers 0.0533
juices 0.

In [3]:
def build_label_embeddings(label_keys, label_tfidf, dense: bool = True):
    """
    label_keys: ÎùºÎ≤® Ïù¥Î¶Ñ Î¶¨Ïä§Ìä∏ (vectorizeÌï† Îïå ÏçºÎçò ÏàúÏÑúÏôÄ Í∞ôÏïÑÏïº Ìï®)
    label_tfidf: shape = (n_labels, vocab_size) Ïù∏ sparse matrix
    dense: TrueÎ©¥ numpy arrayÎ°ú Î∞îÍøîÏÑú ÎèåÎ†§Ï§å

    return:
        dict: {label_name: embedding_vector}
    """
    embeddings = {}
    if dense:
        label_tfidf_dense = label_tfidf.toarray()
        for i, label in enumerate(label_keys):
            embeddings[label] = label_tfidf_dense[i]
    else:
        # sparse Í∑∏ÎåÄÎ°ú
        for i, label in enumerate(label_keys):
            embeddings[label] = label_tfidf[i]
    return embeddings

label_embeddings = build_label_embeddings(label_keys, label_tfidf, dense=True)
print(label_embeddings["grocery_gourmet_food"].shape)  # (vocab_size,)


(3466,)


In [4]:

def load_edges(path):
    edges = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            raw = line.strip()
            if not raw or raw.startswith("#"):
                continue
            parts = raw.split()
            if len(parts) < 2:
                continue
            try:
                u, v = int(parts[0]), int(parts[1])
            except ValueError:
                continue
            edges.append((u, v))
    return edges

def find_roots(edges):
    parents = set()
    children = set()
    for u, v in edges:
        parents.add(u)
        children.add(v)
    # Î∂ÄÎ™®Î°úÎßå ÎÇòÏò® Ïï†Îì§ = Î£®Ìä∏Îì§
    roots = parents - children
    return sorted(roots)

# --- ÏÇ¨Ïö© ---
E = load_edges("Amazon_products/class_hierarchy.txt")

N = 531
A = np.zeros((N, N), dtype=np.uint8)
for u, v in E:
    A[u, v] = 1
    A[v, u] = 1   # ÌÉêÏÉâÏö©ÏúºÎ°úÎäî Î¨¥Î∞©Ìñ• Ïù∏Ï†ëÌñâÎ†¨ Ïç®ÎèÑ Îê®

B = np.zeros((N, N), dtype=np.uint8)
for u, v in E:
    B[u, v] = 1

roots = find_roots(E)
print("roots:", roots)


roots: [0, 3, 10, 23, 40, 169]


In [6]:

# ---------------------------
# GAT 
# ---------------------------

class SimpleGATLayer(nn.Module):
    def __init__(self, in_dim, out_dim, heads=4, concat=True, dropout=0.2, negative_slope=0.2, residual=True):
        super().__init__()
        self.heads = heads
        self.out_dim = out_dim
        self.concat = concat
        self.dropout = nn.Dropout(dropout)
        self.leaky_relu = nn.LeakyReLU(negative_slope)
        self.lin = nn.Linear(in_dim, heads * out_dim, bias=False)
        self.a_src = nn.Parameter(torch.Tensor(heads, out_dim))
        self.a_dst = nn.Parameter(torch.Tensor(heads, out_dim))
        self.residual = residual
        if residual and (in_dim == (heads * out_dim if concat else out_dim)):
            self.res_proj = nn.Identity()
        elif residual:
            self.res_proj = nn.Linear(in_dim, heads * out_dim if concat else out_dim, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.lin.weight)
        nn.init.xavier_uniform_(self.a_src)
        nn.init.xavier_uniform_(self.a_dst)
        if self.residual and not isinstance(getattr(self, "res_proj", None), nn.Identity):
            nn.init.xavier_uniform_(self.res_proj.weight)

    def forward(self, x, adj):
        """
        x: [N, Fin]
        adj: [N, N] (0/1; self-loop ÏóÜÏùå)
        """
        N = x.size(0)
        Wh = self.lin(x).view(N, self.heads, self.out_dim)  # [N, H, F]

        e_src = (Wh * self.a_src).sum(dim=-1)  # [N, H]
        e_dst = (Wh * self.a_dst).sum(dim=-1)  # [N, H]
        e = e_src.unsqueeze(1) + e_dst.unsqueeze(0)  # [N, N, H]
        e = self.leaky_relu(e)
        # --- ÏïàÏ†ÑÌïú masked softmax ---
        mask = (adj > 0).unsqueeze(-1)                    # [N, N, 1]
        e = e.masked_fill(~mask, -1e9)                    # -inf ÎåÄÏã† -1e9Î°ú NaN Î∞©ÏßÄ
        alpha = torch.softmax(e, dim=1)                   # ÏÜåÌîÑÌä∏Îß•Ïä§
        alpha = alpha * mask.float()                      # ÎßàÏä§ÌÅ¨Î°ú 0 Ï≤òÎ¶¨
        denom = alpha.sum(dim=1, keepdim=True).clamp(min=1e-12)  # Ïù¥ÏõÉ ÏóÜÏùÑ Îïå 0 Î∂ÑÎ™® Î∞©ÏßÄ
        alpha = alpha / denom                             # Ïù¥ÏõÉÎì§Î°ú Ï†ïÍ∑úÌôî

        out = torch.einsum("ijh,jhf->ihf", alpha, Wh)     # [N, H, F]
        out = out.reshape(N, self.heads * self.out_dim) if self.concat else out.mean(dim=1)
        out = self.dropout(out)
        if self.residual:
            out = out + self.res_proj(x)                  # self-loop ÏóÜÎäî ÎåÄÏã† residualÎ°ú ÏûêÍ∏∞Ï†ïÎ≥¥ Ïú†ÏßÄ
        return out

class GATEncoder(nn.Module):
    def __init__(self, in_dim, hid_dim=64, out_dim=768, heads1=4, heads2=4, dropout=0.2):
        super().__init__()
        self.gat1 = SimpleGATLayer(in_dim, hid_dim, heads=heads1, concat=True,  dropout=dropout, residual=True)
        self.gat2 = SimpleGATLayer(hid_dim*heads1, out_dim, heads=heads2, concat=False, dropout=dropout, residual=True)
        self.act = nn.ELU(); self.dropout = nn.Dropout(dropout)
    def forward(self, x, adj):
        h = self.gat1(x, adj); h = self.act(h); h = self.dropout(h)
        z = self.gat2(h, adj)
        return z  # [N, out_dim]


In [7]:
import pandas as pd

In [8]:
# ---------------------------
# ÌïôÏäµ Ïú†Ìã∏: ÏùåÏÑ± Ïó£ÏßÄ ÏÉòÌîå/Î°úÏä§
# ---------------------------
def to_upper_pos_edges(A):
    pos = []
    N = A.shape[0]
    for i in range(N):
        for j in range(i+1, N):
            if A[i, j] == 1:
                pos.append((i, j))
    return pos

def sample_neg(A, k):
    N = A.shape[0]
    neg = set()
    while len(neg) < k:
        u = np.random.randint(0, N); v = np.random.randint(0, N)
        if u == v: continue
        a, b = (u, v) if u < v else (v, u)
        if A[a, b] == 0:
            neg.add((a, b))
    return list(neg)

def sample_neg_excluding(A, k, exclude_edges):
    """
    A: np.array [N,N]  (0/1)
    k: ÎΩëÏùÑ ÏùåÏÑ± Í∞úÏàò
    exclude_edges: {(u,v), ...}  Î¨¥Ï°∞Í±¥ ÎπºÏïº ÌïòÎäî ÏñëÏÑ±(ÎòêÎäî Í∏àÏßÄ) Ïó£ÏßÄÎì§ (u<v ÌòïÌÉúÎ°ú ÎÑ£Í∏∞)
    """
    N = A.shape[0]
    neg = set()
    while len(neg) < k:
        u = np.random.randint(0, N); v = np.random.randint(0, N)
        if u == v:
            continue
        a, b = (u, v) if u < v else (v, u)
        if A[a, b] == 0 and (a, b) not in exclude_edges:
            neg.add((a, b))
    return list(neg)


def edge_score(z, edges):
    u = torch.tensor([a for a, _ in edges], device=z.device, dtype=torch.long)
    v = torch.tensor([b for _, b in edges], device=z.device, dtype=torch.long)
    return (z[u] * z[v]).sum(dim=1)  # ÎÇ¥Ï†Å ÎîîÏΩîÎçî

from sklearn.metrics import roc_auc_score
def eval_auc(z, pos_edges, A_full, k_factor=1.0):
    z = F.normalize(z, p=2, dim=1)
    neg_edges = sample_neg(A_full, int(len(pos_edges) * k_factor))
    s = torch.cat([edge_score(z, pos_edges), edge_score(z, neg_edges)]).detach().cpu().numpy()
    y = np.concatenate([np.ones(len(pos_edges)), np.zeros(len(neg_edges))])
    return roc_auc_score(y, s)

hidden_dim=64
out_dim=3466
heads1=8
heads2=8
dropout=0.2
epochs=200
lr=1e-3
weight_decay=5e-4
neg_ratio=1.0
eval_every=20
use_full_graph_for_final=True
pad_width=2
normalize_out = True
device = "cuda" if torch.cuda.is_available() else "cpu"

ids = np.arange(len(label_keys), dtype=np.int64)
X = np.vstack([label_embeddings[k] for k in label_keys]).astype(np.float32)
X = torch.tensor(X, dtype=torch.float32, device=device)

N, d0 = X.shape
pos_edges = to_upper_pos_edges(A)

rng = np.random.default_rng(42)
idx = rng.permutation(len(pos_edges))
n_val = max(1, int(0.1 * len(pos_edges)))          # 10% val
pos_val = [pos_edges[i] for i in idx[:n_val]]
pos_train = [pos_edges[i] for i in idx[n_val:]]

# train Í∑∏ÎûòÌîÑÎßåÏúºÎ°ú ÌïôÏäµ(ÎàÑÏ∂ú Î∞©ÏßÄ)
A_train = np.zeros_like(A)
for u, v in pos_train:
    A_train[u, v] = 1; A_train[v, u] = 1

adj_train = torch.tensor(A_train, dtype=torch.float32, device=device)
# ÌÖêÏÑú
x = torch.tensor(X, dtype=torch.float32, device=device)
adj = torch.tensor(A, dtype=torch.float32, device=device)  # softmax ÎßàÏä§ÌÅ¨Ïö©

model = GATEncoder(in_dim=d0, hid_dim=hidden_dim, out_dim=out_dim, heads1=heads1, heads2=heads2, dropout=dropout).to(device)

opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
bce = nn.BCEWithLogitsLoss()

maxauc = 0
best_ckpt = "Amazon_products/best_gat.ckpt"
# Í∏àÏßÄ Ïó£ÏßÄ ÏßëÌï© (train+val Î™®Îëê)
forbidden = set()
for u, v in pos_edges:        # pos_edges = train+val Ï†ÑÏ≤¥
    a, b = (u, v) if u < v else (v, u)
    forbidden.add((a, b))

for ep in range(1, epochs+1):
    model.train()
    # üî¥ Ïó¨Í∏∞ÏÑú Ï†ÑÏ≤¥ adj ÎßêÍ≥† trainÏö© adjÎßå Î≥∏Îã§
    z = model(x, adj_train)                          # [N, out_dim]
    if normalize_out:
        z = F.normalize(z, p=2, dim=1)

    # üî¥ Ïã§Ï†úÎ°ú ÌïôÏäµÏóê Ïì∞Îäî ÏñëÏÑ± Ïàò Í∏∞Ï§ÄÏúºÎ°ú ÏùåÏÑ± Ïàò Í≤∞Ï†ï
    num_pos = len(pos_train)
    num_neg = int(num_pos * neg_ratio)
    # üî¥ train Í∑∏ÎûòÌîÑ Í∏∞Ï§ÄÏúºÎ°ú ÎΩëÎêò, train+val ÏñëÏÑ±ÏùÄ Î¨¥Ï°∞Í±¥ Ï†úÏô∏
    neg_edges = sample_neg_excluding(A_train, num_neg, forbidden)

    score_pos = edge_score(z, pos_train)
    score_neg = edge_score(z, neg_edges)
    scores = torch.cat([score_pos, score_neg], dim=0)
    labels = torch.cat([torch.ones_like(score_pos), torch.zeros_like(score_neg)], dim=0)

    loss = bce(scores, labels)

    opt.zero_grad()
    loss.backward()
    opt.step()

    # ÌèâÍ∞Ä Î∂ÄÎ∂ÑÏùÄ Í±∞Ïùò Í∑∏ÎåÄÎ°ú
    if ep % 1 == 0 or ep == 1:
        model.eval()
        with torch.no_grad():
            # valÏùÄ Ïó¨Ï†ÑÌûà train Í∑∏ÎûòÌîÑÎ°ú ÏûÑÎ≤†Îî©
            z_val = F.normalize(model(x, adj_train), p=2, dim=1)
            auc_val = eval_auc(z_val, pos_val, A, k_factor=1.0)
        print(f"[{ep:03d}/{epochs}] loss={loss.item():.4f} | "
              f"pos={score_pos.mean().item():.3f} neg={score_neg.mean().item():.3f} | "
              f"val AUC={auc_val:.4f}")
        if maxauc < auc_val:
            maxauc = auc_val
            torch.save(model.state_dict(), best_ckpt)

model.load_state_dict(torch.load(best_ckpt, weights_only=True))

# ÏµúÏ¢Ö ÏûÑÎ≤†Îî© Ï∂îÏ∂ú
model.eval()
with torch.no_grad():
    z = model(x, adj)
    if normalize_out:
        z = F.normalize(z, p=2, dim=1)
    Z = z.detach().cpu().numpy()  # [N, out_dim]
OUT_CSV = "Amazon_products/label_emb_tf"
# CSV Ï†ÄÏû• (id + feat00..)
pad = max(2, len(str(out_dim-1)))
feat_cols = [f"feat{str(i).zfill(pad)}" for i in range(out_dim)]
df = pd.DataFrame(Z, columns=feat_cols)
df.insert(0, "id", ids)
df.to_csv(OUT_CSV, index=False)
print(f"[OK] saved GAT label embeddings ‚Üí {OUT_CSV}  shape={df.shape}")



[001/700] loss=0.6678 | pos=0.120 neg=0.010 | val AUC=0.7806
[002/700] loss=0.6270 | pos=0.303 neg=0.011 | val AUC=0.8425
[003/700] loss=0.5963 | pos=0.452 neg=0.009 | val AUC=0.8243
[004/700] loss=0.5750 | pos=0.566 neg=0.008 | val AUC=0.8485
[005/700] loss=0.5626 | pos=0.652 neg=0.016 | val AUC=0.8638
[006/700] loss=0.5477 | pos=0.712 neg=-0.003 | val AUC=0.8221
[007/700] loss=0.5384 | pos=0.762 neg=-0.012 | val AUC=0.8323
[008/700] loss=0.5343 | pos=0.795 neg=-0.010 | val AUC=0.8536
[009/700] loss=0.5267 | pos=0.820 neg=-0.029 | val AUC=0.8093
[010/700] loss=0.5323 | pos=0.842 neg=0.002 | val AUC=0.8473
[011/700] loss=0.5279 | pos=0.859 neg=-0.008 | val AUC=0.8042
[012/700] loss=0.5269 | pos=0.872 neg=-0.005 | val AUC=0.8508
[013/700] loss=0.5281 | pos=0.881 neg=0.004 | val AUC=0.8217
[014/700] loss=0.5197 | pos=0.889 neg=-0.026 | val AUC=0.8689
[015/700] loss=0.5234 | pos=0.894 neg=-0.012 | val AUC=0.8383
[016/700] loss=0.5224 | pos=0.901 neg=-0.015 | val AUC=0.7924
[017/700] loss=

In [5]:
def load_docs_txt(path):
    """
    'idx<TAB>text' ÌòïÌÉúÏùò ÌååÏùºÏùÑ ÏùΩÏñ¥ÏÑú
    ids: [int, ...]
    texts: [str, ...]
    ÏùÑ Î¶¨ÌÑ¥
    """
    ids = []
    texts = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            # ÌÉ≠ Í∏∞Ï§Ä
            idx_str, txt = line.split("\t", 1)
            ids.append(int(idx_str))
            texts.append(txt)
    return ids, texts



def build_doc_embeddings_from_existing_vectorizer(doc_texts, vectorizer):
    """
    doc_texts: Ï†ÑÏ≤òÎ¶¨ Ï†ÑÏùò ÏõêÎ¨∏ Î¶¨Ïä§Ìä∏
    vectorizer: ÎùºÎ≤®Ïóê ÎåÄÌï¥ fitÎêòÏñ¥ ÏûàÎäî TfidfVectorizer
    return: dense numpy array [N_docs, vocab]
    """
    # ÎùºÎ≤®Ïù¥Îûë ÎèôÏùº Í∑úÏπôÏúºÎ°ú Ï†ÑÏ≤òÎ¶¨
    cleaned_docs = [preprocess_label_text(t) for t in doc_texts]
    doc_tfidf = vectorizer.transform(cleaned_docs)   # sparse
    doc_emb = doc_tfidf.toarray().astype(np.float32)
    return doc_emb

# ÏÇ¨Ïö© ÏòàÏãú
# 1) Î¨∏ÏÑú ÏùΩÍ∏∞
doc_ids, doc_texts = load_docs_txt("Amazon_products/train/train_corpus.txt")

# 2) ÎùºÎ≤® Îïå ÎßåÎì† vectorizer Ïû¨ÏÇ¨Ïö©Ìï¥ÏÑú ÏûÑÎ≤†Îî© ÎßåÎì§Í∏∞
doc_embeddings = build_doc_embeddings_from_existing_vectorizer(doc_texts, vectorizer)

In [6]:
print(doc_embeddings.shape)

(29487, 3466)


In [6]:
N = 531 
B = np.zeros((N, N), dtype=np.uint8)

for u, v in E:
    B[u, v] = 1
print(B)
print(roots)

[[0 1 0 ... 0 0 0]
 [0 0 1 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]
[0, 3, 10, 23, 40, 169]


In [7]:
import numpy as np
import numpy as np

def hierarchical_beam_similarity_avg(
    doc_vec: np.ndarray,
    label_emb: np.ndarray,
    adj_upper: np.ndarray,
    roots: list[int] = [0],       # Ïó¨Îü¨ Î£®Ìä∏
    beam: int = 5,
    per_parent: str | int = "l+2",
    tau: float = 0.35,
    eps: float = 1e-9,
    max_depth: int | None = None,
    normalize: bool = False,      # ÌïÑÏöîÌïòÎ©¥ TrueÎ°ú
):
    doc = np.asarray(doc_vec, dtype=np.float32)
    L = np.asarray(label_emb, dtype=np.float32)
    A = np.asarray(adj_upper).astype(bool)
    N, d = L.shape

    if normalize:
        doc = doc / (np.linalg.norm(doc) + eps)
        L = L / (np.linalg.norm(L, axis=1, keepdims=True) + eps)

    # Î°úÏª¨ Ï†êÏàò
    sims = L @ doc
    p = 1.0 / (1.0 + np.exp(-sims / max(tau, 1e-6)))

    children = [np.flatnonzero(A[i]) for i in range(N)]

    S = np.full(N, -np.inf, dtype=np.float32)
    K = np.full(N, -np.inf, dtype=np.float32)
    Llen = np.zeros(N, dtype=np.int32)

    roots = list(roots)
    for r in roots:
        S[r] = 0.0
        Llen[r] = 0
        K[r] = -np.inf

    levels = [roots[:]]
    cur = roots[:]
    level_id = 0

    while True:
        cand_best = {}
        k_parent = (level_id + 2) if (per_parent == "l+2") else int(per_parent)

        for par in cur:
            ch = children[par]
            if ch.size == 0:
                continue
            if ch.size > k_parent:
                idx = np.argpartition(-sims[ch], k_parent - 1)[:k_parent]
                ch = ch[idx]
            for c in ch:
                S_c = S[par] + float(p[c])
                L_c = Llen[par] + 1
                K_c = S_c / (L_c + eps)
                if (c not in cand_best) or (K_c > cand_best[c][2]):
                    cand_best[c] = (S_c, L_c, K_c)

        if not cand_best:
            break

        kept = sorted(cand_best.items(), key=lambda x: x[1][2], reverse=True)[:min(beam, len(cand_best))]
        next_level = [i for i, _ in kept]
        for i, (Si, Li, Ki) in kept:
            S[i], Llen[i], K[i] = Si, Li, Ki

        levels.append(next_level)
        cur = next_level
        level_id += 1
        if max_depth is not None and level_id >= max_depth:
            break

    return K, levels, sims, p



def topk_labels_by_avg(
    doc_vec, label_emb, adj_upper, root„Ñ¥=(0,), beam=5, per_parent="l+2", k=5, **kw
):
    """ÌèâÍ∑† Ï†êÏàò Í∏∞Î∞ò ÏµúÏ¢Ö ÏÉÅÏúÑ k ÎùºÎ≤®(Î£®Ìä∏ Ï†úÏô∏)."""
    K, levels, sims, p = hierarchical_beam_similarity_avg(
        doc_vec, label_emb, adj_upper, root=list(roots), beam=beam, per_parent=per_parent, **kw
    )
    root_set = set(roots)
    order = np.argsort(-K)
    order = [i for i in order if i not in root_set and np.isfinite(K[i])]
    top = order[:k]
    return top, K[top]

In [8]:
"""
Self-training pipeline with hierarchical silver labeling and dynamic dataloaders.

- Reads document/label embeddings CSVs (first column "id", rest feat000..feat127)
- Reads upper-triangular adjacency (A[i,j]=1 means i->j)
- Makes initial silver labels via hierarchical beam search (average score)
- Splits into train/val on silver set; keeps the rest as unlabeled pool
- Trains a multi-label classifier (Linear/MLP) with BCEWithLogitsLoss
- Each epoch, pseudo-labels unlabeled docs whose predicted probs exceed a threshold
- Adds them to the training set (up to top_k per doc), with patience-based early stopping

Run example
-----------
python self_training_pipeline.py \
  --doc_csv docs.csv \
  --label_csv labels.csv \
  --adj adj.npy \
  --val_ratio 0.2 --epochs 50 --patience 5 \
  --silver_threshold 0.60 --silver_topk 3 --beam 5 --tau 0.35 --root_id 0 \
  --pseudo_threshold 0.70 --pseudo_topk 3 --batch_size 256 --lr 1e-3
"""
from __future__ import annotations
import argparse
from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple, Dict

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
def load_embeddings_csv(path: str | Path, id_col: str = "id") -> Tuple[List[int], np.ndarray]:
    """Load embeddings from CSV where the first column is an id and the rest are feature columns.
    Returns (ids, float32 matrix).
    """
    df = pd.read_csv(path)
    cols = list(df.columns)
    if id_col in df.columns:
        id_series = df[id_col]
        X = df.drop(columns=[id_col])
    else:
        # Fallback: use the first column as id
        id_series = df.iloc[:, 0]
        X = df.iloc[:, 1:]
    ids = id_series.astype(int).tolist()
    X = X.to_numpy(dtype=np.float32)
    return ids, X


# ----------------------------- Datasets -----------------------------

class MultiLabelDataset(Dataset):
    def __init__(self, X: np.ndarray, Y: np.ndarray, indices: List[int] | None = None):
        self.X = X
        self.Y = Y
        self.indices = np.array(indices if indices is not None else np.arange(X.shape[0]), dtype=np.int64)
    def __len__(self):
        return self.indices.shape[0]
    def __getitem__(self, idx: int):
        i = int(self.indices[idx])
        x = torch.from_numpy(self.X[i])
        y = torch.from_numpy(self.Y[i])
        return x, y

class UnlabeledDataset(Dataset):
    def __init__(self, X: np.ndarray, indices: List[int]):
        self.X = X
        self.indices = np.array(indices, dtype=np.int64)
    def __len__(self):
        return self.indices.shape[0]
    def __getitem__(self, idx: int):
        i = int(self.indices[idx])
        x = torch.from_numpy(self.X[i])
        return x, i

# ----------------------------- Model -----------------------------

class MLPHead(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, hidden: int | None = 256, dropout: float = 0.1):
        super().__init__()
        if hidden is None or hidden <= 0:
            self.net = nn.Sequential(
                nn.LayerNorm(in_dim),
                nn.Linear(in_dim, out_dim),
            )
        else:
            self.net = nn.Sequential(
                nn.LayerNorm(in_dim),
                nn.Linear(in_dim, hidden),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(hidden, out_dim),
            )
    def forward(self, x):
        return self.net(x)

# ----------------------------- Utils -----------------------------

def to_device(batch, device):
    if isinstance(batch, (tuple, list)):
        return [b.to(device) if torch.is_tensor(b) else b for b in batch]
    return batch.to(device)


def micro_f1(y_true: np.ndarray, y_prob: np.ndarray, thr: float = 0.5, eps: float = 1e-9) -> float:
    y_pred = (y_prob >= thr).astype(np.float32)
    tp = (y_true * y_pred).sum()
    fp = ((1 - y_true) * y_pred).sum()
    fn = (y_true * (1 - y_pred)).sum()
    prec = tp / (tp + fp + eps)
    rec = tp / (tp + fn + eps)
    f1 = 2 * prec * rec / (prec + rec + eps)
    return float(f1)

# -------- Initial silver labeling (no CSV save; in-memory) --------
def make_initial_silver_hier(
    docs: np.ndarray,
    labels: np.ndarray,
    adj: np.ndarray,
    roots: list[int] = [0],
    silver_threshold: float = 0.6,    # Ïù¥Í±¥ avg(K) Í∏∞Ï§Ä
    silver_topk: int = 3,
    beam: int = 5,
    per_parent: str | int = "l+2",
    tau: float = 0.35,
) -> list[list[int]]:
    """
    Í≥ÑÏ∏µ Îπî ÏÑúÏπòÎ°ú Í∞Å Î¨∏ÏÑúÏùò ÎùºÎ≤® ÌõÑÎ≥¥Î•º ÎΩëÎäîÎã§.
    - Í≥ÑÏ∏µ Î∞ñ ÎùºÎ≤®ÏùÄ Ïï†Ï¥àÏóê Ïïà Îì§Ïñ¥Ïò¥
    - Î£®Ìä∏Îì§ÏùÄ Í≤∞Í≥ºÏóêÏÑú Ï†úÏô∏
    - K(Í≤ΩÎ°ú ÌèâÍ∑†) >= silver_threshold Ïù∏ Ïï†Îì§ Ï§ë top-k
    """
    N = labels.shape[0]
    silver: list[list[int]] = []
    root_set = set(roots)

    for d in docs:
        K, levels, sims, p = hierarchical_beam_similarity_avg(
            d, labels, adj,
            roots=roots,
            beam=beam,
            per_parent=per_parent,
            tau=tau,
            normalize=False,   # ÎÑà ÏûÑÎ≤†Îî©Ïù¥ Ïù¥ÎØ∏ L2ÎùºÎ©¥ False
        )
        # ÌèâÍ∑† Ï†êÏàòÎ°ú Ï†ïÎ†¨
        order = np.argsort(-K)
        # Î£®Ìä∏Îäî Ï†úÏô∏, Ïú†ÌïúÌïú Í≤ÉÎßå
        order = [i for i in order if (i not in root_set) and np.isfinite(K[i])]
        # threshold ÌÜµÍ≥ºÌïú Í≤ÉÎßå
        cand = [i for i in order if K[i] >= silver_threshold]
        selected = cand[:silver_topk]
        silver.append(selected)

    return silver

def make_initial_silver(
    docs: np.ndarray,
    labels: np.ndarray,
    adj: np.ndarray,              # Ïù¥Ï†ú Ïïà ÏîÄ (Ìò∏ÌôòÏö©ÏúºÎ°úÎßå Îë†)
    silver_threshold: float = 0.9,
    silver_topk: int = 3,
    beam: int = 5,                # Ïù¥Ï†ú Ïïà ÏîÄ
    tau: float = 0.35,
    root_id: int = 0,
) -> List[List[int]]:
    """
    Î¨∏ÏÑúÎßàÎã§ Ï†Ñ ÎùºÎ≤® ÏûÑÎ≤†Îî©Í≥ºÏùò Ïú†ÏÇ¨ÎèÑÎ•º Î≥¥Í≥† Ï¥àÍ∏∞ silver labelÏùÑ ÎßåÎì†Îã§.
    - Ìä∏Î¶¨/Í≤ΩÎ°ú ÌÉêÏÉâ Ïïà Ìï®
    - root_idÎäî Í≤∞Í≥ºÏóêÏÑú Ï†úÏô∏
    - p >= silver_thresholdÏù∏ ÎùºÎ≤® Ï§ëÏóêÏÑú ÏÉÅÏúÑ silver_topkÎßå ÎÇ®ÍπÄ
    """
    N = labels.shape[0]
    silver: List[List[int]] = []

    for d in docs:
        # Î¨∏ÏÑú vs Î™®Îì† ÎùºÎ≤® Ï†êÏàò
        sims, p = all_label_similarity(d, labels, tau=tau, normalize=True)

        # threshold ÌÜµÍ≥º + root Ï†úÏô∏
        cand = [
            (i, float(p[i]))
            for i in range(N)
            if i != root_id and np.isfinite(p[i]) and p[i] >= silver_threshold
        ]

        # Ï†êÏàò ÎÜíÏùÄ Ïàú
        cand.sort(key=lambda x: x[1], reverse=True)

        # label indexÎßå Ï∂îÏ∂ú
        selected = [i for i, _ in cand[:silver_topk]]
        silver.append(selected)

    return silver
# ------------------------ Train / Self-Training ------------------------

def train_epoch(model, loader, optim, device, criterion):
    model.train()
    total = 0.0
    for x, y in loader:
        x, y = to_device(x, device), to_device(y, device)
        optim.zero_grad(set_to_none=True)
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optim.step()
        total += float(loss.detach().cpu().item()) * x.size(0)
    return total / max(1, len(loader.dataset))


def eval_epoch(model, loader, device, criterion, thr=0.5):
    model.eval()
    total = 0.0
    ys = []
    ps = []
    with torch.no_grad():
        for x, y in loader:
            x, y = to_device(x, device), to_device(y, device)
            logits = model(x)
            loss = criterion(logits, y)
            total += float(loss.detach().cpu().item()) * x.size(0)
            prob = torch.sigmoid(logits).detach().cpu().numpy()
            ys.append(y.detach().cpu().numpy())
            ps.append(prob)
    y_true = np.concatenate(ys, axis=0)
    y_prob = np.concatenate(ps, axis=0)
    f1 = micro_f1(y_true, y_prob, thr=thr)
    return total / max(1, len(loader.dataset)), f1, y_prob


def pseudo_label_and_grow(model, unl_ds: UnlabeledDataset,
                          num_labels: int,
                          pseudo_threshold: float = 0.9, pseudo_topk: int = 3,
                          device: str = "cpu", batch_size: int = 512):
    """Infer on unlabeled, select labels with prob>=threshold (top-k), and return new_indices and Y matrix."""
    if len(unl_ds) == 0:
        return [], np.zeros((0, num_labels), dtype=np.float32)
    loader = DataLoader(unl_ds, batch_size=batch_size, shuffle=False)
    model.eval()
    all_idx: List[int] = []
    all_y: List[np.ndarray] = []
    with torch.no_grad():
        for xb, idxs in loader:
            xb = xb.to(device)
            logits = model(xb)
            prob = torch.sigmoid(logits).detach().cpu().numpy()
            for p, i in zip(prob, idxs.numpy().tolist()):
                sel = np.flatnonzero(p >= pseudo_threshold)
                if sel.size > 0:
                    # keep at most top-k by prob
                    if sel.size > pseudo_topk:
                        top = np.argpartition(-p[sel], pseudo_topk - 1)[:pseudo_topk]
                        sel = sel[top]
                    y = np.zeros(num_labels, dtype=np.float32)
                    y[sel] = 1.0
                    all_idx.append(int(i))
                    all_y.append(y)
    if len(all_idx) == 0:
        return [], np.zeros((0, num_labels), dtype=np.float32)
    Y_new = np.stack(all_y, axis=0)
    return all_idx, Y_new






In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader
import copy



doc_ids = np.arange(len(doc_embeddings), dtype=np.int64)   # 0..num_docs-1
X = doc_embeddings.astype(np.float32)                      # [num_docs, d_doc]

# ÎùºÎ≤® ÏûÑÎ≤†Îî© ÏÑ∏ÌåÖ
label_ids = np.arange(len(label_keys), dtype=np.int64)     # 0..530
L = np.vstack([label_embeddings[k] for k in label_keys]).astype(np.float32)   # [531, d_label]


# 1) ÎùºÎ≤® ÏàúÏÑúÏôÄ B(Î∂ÄÎ™®->ÏûêÏãù) ÎßûÏ∂îÍ∏∞
order = np.argsort(label_ids)
label_ids = [label_ids[i] for i in order]
L = L[order]
assert B.shape == (L.shape[0], L.shape[0]), "Adjacency/label size mismatch"

# 2) Í≥ÑÏ∏µ silver ÎßåÎì§Í∏∞
silver = make_initial_silver_hier(
    X,          # docs (N, d)
    L,          # label_emb (C, d)
    B,          # upper adj (C, C)
    roots=roots,
    silver_threshold=0.6,
    silver_topk=3,
    beam=5,
    per_parent="l+2",
    tau=0.35,
)


# -------------------------------------------------
# 1) Í≥ÑÏ∏µ Ï†ïÎ≥¥ÏóêÏÑú parents / children ÎΩëÍ∏∞
#    B[parent, child] = 1 Ïù¥ÎùºÍ≥† ÌñàÏúºÎãàÍπå Í∑∏ÎåÄÎ°ú ÏîÄ
# -------------------------------------------------
# B: [C, C] (parent -> child)
def build_parents_children(adj):
    C = adj.shape[0]
    parents = [np.flatnonzero(adj[:, j]).astype(np.int64) for j in range(C)]
    children = [np.flatnonzero(adj[j]).astype(np.int64) for j in range(C)]
    return parents, children

parents, children = build_parents_children(B)


# -------------------------------------------------
# 2) silver ‚Üí Í≥ÑÏ∏µ pos/neg ÎßàÏä§ÌÅ¨Î°ú Î≥ÄÌôò
# -------------------------------------------------
def build_pos_neg_masks(silver, parents, children, num_labels):
    """
    silver: list[list[int]]  # Î¨∏ÏÑúÎßàÎã§ core label indexÎì§
    parents / children: list[np.ndarray]
    return:
      pos_masks: np.array [N_docs, C]
      neg_masks: np.array [N_docs, C]
    """
    N = len(silver)
    C = num_labels
    pos_masks = np.zeros((N, C), dtype=np.float32)
    neg_masks = np.zeros((N, C), dtype=np.float32)

    all_idx = np.arange(C)

    for i, core in enumerate(silver):
        core = list(core)
        # 1) coreÏùò Î∂ÄÎ™®ÍπåÏßÄ positive
        pos_set = set(core)
        for c in core:
            for p in parents[c]:
                pos_set.add(int(p))

        # 2) childrenÏùÄ ÎÇòÏ§ëÏóê negativeÏóêÏÑú Ï†úÏô∏
        child_set = set()
        for c in core:
            for ch in children[c]:
                child_set.add(int(ch))

        # pos ÎßàÏä§ÌÅ¨
        for p in pos_set:
            pos_masks[i, p] = 1.0

        # neg = Ï†ÑÏ≤¥ - pos - children
        for j in all_idx:
            if j in pos_set:
                continue
            if j in child_set:
                continue
            neg_masks[i, j] = 1.0

    return pos_masks, neg_masks

# -------------------------------------------------
# 3) Dataset: Î¨∏ÏÑú ÏûÑÎ≤†Îî© + pos/neg ÎßàÏä§ÌÅ¨
# -------------------------------------------------
class HierMultiLabelDataset(Dataset):
    def __init__(self, X, pos_masks, neg_masks, indices=None):
        self.X = X.astype(np.float32)
        self.pos = pos_masks.astype(np.float32)
        self.neg = neg_masks.astype(np.float32)
        if indices is None:
            self.indices = np.arange(self.X.shape[0], dtype=np.int64)
        else:
            self.indices = np.array(indices, dtype=np.int64)

    def __len__(self):
        return self.indices.shape[0]

    def __getitem__(self, idx):
        i = int(self.indices[idx])
        x = torch.from_numpy(self.X[i])
        pos = torch.from_numpy(self.pos[i])
        neg = torch.from_numpy(self.neg[i])
        return x, pos, neg

class UnlabeledDataset(Dataset):
    def __init__(self, X, indices):
        self.X = X.astype(np.float32)
        self.indices = np.array(indices, dtype=np.int64)
    def __len__(self):
        return len(self.indices)
    def __getitem__(self, idx):
        i = int(self.indices[idx])
        return torch.from_numpy(self.X[i]), i


# -------------------------------------------------
# 4) Bilinear classifier
#    doc_emb: [B, d_doc]
#    label_emb: [C, d_lab]  (ÎØ∏Î¶¨ GATÎ°ú ÎßåÎì† Í±∞)
#    Ï†êÏàò: doc @ W @ label_emb^T
# -------------------------------------------------
class BilinearHierClassifier(nn.Module):
    def __init__(self, doc_dim, label_emb, hidden_dim=None):
        super().__init__()
        # label_embÎäî ÌååÎùºÎØ∏ÌÑ∞Î°ú Îì§Í≥†ÏûàÎêò, ÏóÖÎç∞Ïù¥Ìä∏ Ïïà ÌïúÎã§Í≥† Í∞ÄÏ†ï(ÏõêÌïòÎ©¥ nn.ParameterÎ°ú)
        self.register_buffer("label_emb", torch.tensor(label_emb, dtype=torch.float32))
        C, d_lab = self.label_emb.shape
        self.doc_dim = doc_dim
        self.label_dim = d_lab

        if hidden_dim is None:
            # Î∞îÎ°ú doc_dim -> label_dim
            self.interaction = nn.Linear(doc_dim, d_lab, bias=False)
            self.proj = None
        else:
            # doc_dim -> hidden -> label_dim Í∞ôÏùÄ Í≤ÉÎèÑ Í∞ÄÎä•
            self.interaction = nn.Sequential(
                nn.Linear(doc_dim, hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, d_lab, bias=False),
            )

    def forward(self, x):
        """
        x: [B, d_doc]
        return: logits [B, C]
        """
        # x -> same dim as label
        h = self.interaction(x)                             # [B, d_lab]
        # [B, d_lab] @ [d_lab, C] -> [B, C]
        logits = torch.matmul(h, self.label_emb.t())
        return logits

# -------------------------------------------------
# 5) loss: Í≥ÑÏ∏µ ÎßàÏä§ÌÅ¨Î•º ÏîåÏö¥ BCE
# -------------------------------------------------
def hierarchical_bce_loss(logits, pos_mask, neg_mask):
    # logits: [B, C]
    # pos_mask, neg_mask: [B, C]
    loss_pos = -(pos_mask * F.logsigmoid(logits)).sum()
    loss_neg = -(neg_mask * F.logsigmoid(-logits)).sum()
    denom = (pos_mask.sum() + neg_mask.sum()).clamp(min=1.0)
    return (loss_pos + loss_neg) / denom

# -------------------------------------------------
# 6) ÌïôÏäµ Î£®ÌîÑ ÏòàÏãú
# -------------------------------------------------
# Ïù¥ÎØ∏ ÏûàÎäî Í≤ÉÎì§: X (Î¨∏ÏÑú BERT ÏûÑÎ≤†Îî©) : [N_docs, d_doc]
#                  L (ÎùºÎ≤® GAT ÏûÑÎ≤†Îî©)  : [C, d_lab]
#                  B_adj (Î∂ÄÎ™®->ÏûêÏãù)   : [C, C]
#                  silver (list[list[int]]) : Î¨∏ÏÑúÎ≥Ñ core label index
def train_epoch_hier(model, loader, opt, device):
    model.train()
    total = 0.0
    for xb, posb, negb in loader:
        xb = xb.to(device)
        posb = posb.to(device)
        negb = negb.to(device)
        logits = model(xb)
        loss = hierarchical_bce_loss(logits, posb, negb)
        opt.zero_grad()
        loss.backward()
        @torch.no_grad()
        def update_ema(student, teacher, m):
            for p_s, p_t in zip(student.parameters(), teacher.parameters()):
                # teacher = m * teacher + (1-m) * student
                p_t.data.mul_(m).add_(p_s.data, alpha=1.0 - m)
        opt.step()
        update_ema(model, model_ema, ema_momentum)
        total += loss.item() * xb.size(0)
    return total / len(loader.dataset)

# 1) micro F1 Í≥ÑÏÇ∞
def micro_f1_from_logits(logits, pos_mask, thr=0.5, eps=1e-9):
    """
    logits: [B, C]
    pos_mask: [B, C]  (1: positive, 0: else)
    """
    probs = torch.sigmoid(logits)
    preds = (probs >= thr).float()

    y_true = pos_mask
    y_pred = preds

    tp = (y_true * y_pred).sum()
    fp = ((1 - y_true) * y_pred).sum()
    fn = (y_true * (1 - y_pred)).sum()

    precision = tp / (tp + fp + eps)
    recall    = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    return f1.item()

# 2) eval Ìï®Ïàò ÏàòÏ†ï: loss + f1 Îëò Îã§
def eval_epoch_hier(model, loader, device, k=3, thr=None):
    model.eval()
    total_loss = 0.0
    f1_list = []
    with torch.no_grad():
        for xb, posb, negb in loader:
            xb = xb.to(device)
            posb = posb.to(device)
            negb = negb.to(device)

            logits = model(xb)
            loss = hierarchical_bce_loss(logits, posb, negb)  # ÏúÑÏóê Î∞îÍæº Î≤ÑÏ†Ñ
            total_loss += loss.item() * xb.size(0)

            probs = torch.sigmoid(logits)

            if thr is not None:
                pred = (probs >= thr).float()
            else:
                # top-k Î∞©Ïãù
                B, C = probs.shape
                pred = torch.zeros_like(probs)
                topk = probs.topk(k, dim=1).indices
                pred.scatter_(1, topk, 1.0)

            # micro-f1
            y_true = posb
            y_pred = pred
            tp = (y_true * y_pred).sum().item()
            fp = ((1 - y_true) * y_pred).sum().item()
            fn = (y_true * (1 - y_pred)).sum().item()
            prec = tp / (tp + fp + 1e-9)
            rec  = tp / (tp + fn + 1e-9)
            f1   = 2 * prec * rec / (prec + rec + 1e-9)
            f1_list.append(f1)

    avg_loss = total_loss / len(loader.dataset)
    avg_f1 = float(np.mean(f1_list)) if f1_list else 0.0
    return avg_loss, avg_f1
def pseudo_label_and_grow_hier(
    model,
    unl_ds,             # UnlabeledDataset
    X_all,              # Ï†ÑÏ≤¥ Î¨∏ÏÑú ÏûÑÎ≤†Îî© (numpy)
    parents, children,
    num_labels,
    device,
    pseudo_threshold=0.45,
    pseudo_topk=3,
    batch_size=512,
):
    if len(unl_ds) == 0:
        return [], None, None

    loader = DataLoader(unl_ds, batch_size=batch_size, shuffle=False)
    model.eval()
    new_idx = []
    new_pos_list = []
    new_neg_list = []

    with torch.no_grad():
        for xb, idxs in loader:
            xb = xb.to(device)
            logits = model(xb)
            prob = torch.sigmoid(logits).cpu().numpy()

            for p, i_doc in zip(prob, idxs.numpy().tolist()):
                order = np.argsort(-p)
                top1 = p[order[0]]
                # 1) top-1Ïù¥ thresholdÎ•º Î™ª ÎÑòÏúºÎ©¥ Í∑∏ÎÉ• Î≤ÑÎ¶∞Îã§
                if top1 < pseudo_threshold:
                    continue
                core = [j for j in order if p[j] >= pseudo_threshold][:pseudo_topk]
                if len(core) == 0:
                    # ÏïÑÏòà Ïù¥ Î¨∏ÏÑúÎäî Ïù¥Î≤à epochÏóê Ïïà ÎÑ£Ïùå
                    continue

                # Í≥ÑÏ∏µ pos/neg Íµ¨ÏÑ±
                pos = set(core)
                for c in core:
                    for pa in parents[c]:
                        pos.add(int(pa))
                child = set()
                for c in core:
                    for ch in children[c]:
                        child.add(int(ch))

                pos_mask = np.zeros(num_labels, dtype=np.float32)
                neg_mask = np.zeros(num_labels, dtype=np.float32)
                for j in pos:
                    pos_mask[j] = 1.0
                for j in range(num_labels):
                    if j in pos:    # Ïù¥ÎØ∏ ÏñëÏÑ±
                        continue
                    if j in child:  # Î™®Î•¥Í≤†Ïùå ‚Üí negativeÏóêÏÑú Ï†úÏô∏
                        continue
                    neg_mask[j] = 1.0

                new_idx.append(int(i_doc))
                new_pos_list.append(pos_mask)
                new_neg_list.append(neg_mask)



    if len(new_idx) == 0:
        return [], None, None

    new_pos = np.stack(new_pos_list, axis=0)
    new_neg = np.stack(new_neg_list, axis=0)
    return new_idx, new_pos, new_neg



device = "cuda" if torch.cuda.is_available() else "cpu"
has_silver = np.array([len(lbls) > 0 for lbls in silver], dtype=bool)
N_docs = X.shape[0]
C = L.shape[0]

# silver ÏûàÎäî Î¨∏ÏÑú / ÏóÜÎäî Î¨∏ÏÑú
has_silver = np.array([len(lbls) > 0 for lbls in silver], dtype=bool)
idx_silver = np.flatnonzero(has_silver)      # Ïó¨Í∏∞Í∞Ä train/val ÌõÑÎ≥¥
idx_unl    = np.flatnonzero(~has_silver)     # ÏßÑÏßú unl

print("total:", N_docs)
print("with silver:", len(idx_silver))
print("unlabeled :", len(idx_unl))

# Ïù¥Ï†ú train/valÏùÄ silver ÏûàÎäî Ïï†Îì§Îßå ÏÑûÏñ¥ÏÑú ÎÇòÎààÎã§
rng = np.random.default_rng(42)
rng.shuffle(idx_silver)
n_val = int(len(idx_silver) * 0.2)
idx_val   = idx_silver[:n_val]
idx_train = idx_silver[n_val:]

# parents, children ÎßåÎì§Í∏∞
def build_parents_children(adj):
    C = adj.shape[0]
    parents = [np.flatnonzero(adj[:, j]).astype(np.int64) for j in range(C)]
    children = [np.flatnonzero(adj[j]).astype(np.int64) for j in range(C)]
    return parents, children

parents, children = build_parents_children(B)

pos_masks = np.zeros((N_docs, C), dtype=np.float32)
neg_masks = np.zeros((N_docs, C), dtype=np.float32)

for i in idx_silver:  # silver ÏûàÎäî Ïï†Îßå ÎèàÎã§
    core = silver[i]

    # 1) core + parents
    pos = set(core)
    for c in core:
        for p in parents[c]:
            pos.add(int(p))

    # 2) childrenÏùÄ Î™®Î¶Ñ
    child = set()
    for c in core:
        for ch in children[c]:
            child.add(int(ch))

    for p in pos:
        pos_masks[i, p] = 1.0

    for j in range(C):
        if j in pos:      # Ïù¥ÎØ∏ ÏñëÏÑ±
            continue
        if j in child:    # Î™®Î¶Ñ
            continue
        neg_masks[i, j] = 1.0



train_ds = HierMultiLabelDataset(X, pos_masks, neg_masks, indices=idx_train)
val_ds   = HierMultiLabelDataset(X, pos_masks, neg_masks, indices=idx_val) if len(idx_val) > 0 else None
unl_ds   = UnlabeledDataset(X, idx_unl.tolist())
print(len(train_ds),len(val_ds),len(unl_ds))

train_loader = DataLoader(train_ds, batch_size=256, shuffle=True)
val_loader   = DataLoader(val_ds, batch_size=256, shuffle=False)

model = BilinearHierClassifier(doc_dim=X.shape[1], label_emb=L, hidden_dim=256).to(device)
model_ema = copy.deepcopy(model).to(device)
for p in model_ema.parameters():
    p.requires_grad = False  # teacherÎäî gradient Ïïà Î∞õÏùå

ema_momentum = 0.99  # ÎòêÎäî 0.997, 0.999 Îì±
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
epochs = 500

N_labels = L.shape[0]
best_f1 = -1.0
patience = 20
no_improve = 0
warmup_self = 1   # 1 epochÏùÄ self-training Ïïà ÌïòÍ≤å Ìï¥ÏÑú Ìïú Î≤à ÏïàÏ†ïÌôî

for epoch in range(1, epochs + 1):
    # train
    tr_loss = train_epoch_hier(model, train_loader, opt, device)

    # val: f1 Í∏∞Ï§Ä
    if val_loader is not None and len(val_ds) > 0:
        va_loss, va_f1 = eval_epoch_hier(model, val_loader, device, k=3)
        print(f"Epoch {epoch:03d} | train_loss={tr_loss:.3f}  val_loss={va_loss:.3f}  val_f1={va_f1:.3f}")

        # early stoppingÏùÑ f1Î°ú
        if va_f1 > best_f1 + 1e-6:
            best_f1 = va_f1
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= patience:
                print(f"Early stopping at epoch {epoch} (best f1={best_f1:.4f})")
                break
    else:
        print(f"Epoch {epoch:03d} | train_loss={tr_loss:.3f}")

    # self-training: 1ÏóêÌè≠Ïóê Ï†ÑÎ∂Ä Îì§Ïñ¥Í∞ÄÎäî Í±∞ Î∞©ÏßÄÏö©ÏúºÎ°ú warmup ÎÑ£Ïùå
    if epoch <= warmup_self:
        print("  + (skip pseudo-labeling on warmup epoch)")
        continue

    new_idx, new_pos, new_neg = pseudo_label_and_grow_hier(
        model,
        unl_ds,
        X,
        parents,
        children,
        C,                   # num_labels
        device=device,
        pseudo_threshold=0.45,
        pseudo_topk=3,
        batch_size=512,
    )

    if len(new_idx) > 0:
        # Ï†ÑÏó≠ ÎßàÏä§ÌÅ¨ Í∞±Ïã†
        pos_masks[new_idx] = new_pos
        neg_masks[new_idx] = new_neg

        # unlÏóêÏÑú Ï†úÍ±∞
        keep_mask = ~np.isin(unl_ds.indices, np.array(new_idx, dtype=np.int64))
        unl_ds.indices = unl_ds.indices[keep_mask]

        # trainÏóê Ï∂îÍ∞Ä
        train_ds.indices = np.concatenate([train_ds.indices, np.array(new_idx, dtype=np.int64)])
        train_loader = DataLoader(train_ds, batch_size=256, shuffle=True, drop_last=False)

        print(f"  + Added {len(new_idx)} pseudo-labeled docs (unl pool ‚Üí {len(unl_ds)} left)")
    else:
        print("  + No pseudo-labeled docs added this epoch")


total: 29487
with silver: 12014
unlabeled : 17473
9612 2402 17473
Epoch 001 | train_loss=0.446  val_loss=0.115  val_f1=0.006
  + (skip pseudo-labeling on warmup epoch)
Epoch 002 | train_loss=0.060  val_loss=0.042  val_f1=0.150
  + No pseudo-labeled docs added this epoch
Epoch 003 | train_loss=0.039  val_loss=0.037  val_f1=0.150
  + No pseudo-labeled docs added this epoch
Epoch 004 | train_loss=0.036  val_loss=0.034  val_f1=0.165
  + No pseudo-labeled docs added this epoch
Epoch 005 | train_loss=0.034  val_loss=0.033  val_f1=0.165
  + No pseudo-labeled docs added this epoch
Epoch 006 | train_loss=0.033  val_loss=0.033  val_f1=0.165
  + No pseudo-labeled docs added this epoch
Epoch 007 | train_loss=0.032  val_loss=0.032  val_f1=0.165
  + No pseudo-labeled docs added this epoch
Epoch 008 | train_loss=0.032  val_loss=0.032  val_f1=0.193
  + No pseudo-labeled docs added this epoch
Epoch 009 | train_loss=0.032  val_loss=0.031  val_f1=0.166
  + No pseudo-labeled docs added this epoch
Epoch 01

In [14]:
import csv, os
from pathlib import Path
import numpy as np
import pandas as pd

# ------------ Paths (edit if needed) ------------
TEST_CORPUS = "Amazon_products/test/test_corpus.txt"   # lines: pid \t text
OUT_PATH    = "submission_bda.csv"
# ------------ Hyperparams ------------
MIN_LABS  = 2
MAX_LABS  = 3
BATCH = 1024
doc_ids, doc_texts = load_docs_txt(TEST_CORPUS)

# 2) ÎùºÎ≤® Îïå ÎßåÎì† vectorizer Ïû¨ÏÇ¨Ïö©Ìï¥ÏÑú ÏûÑÎ≤†Îî© ÎßåÎì§Í∏∞
test_embeddings = build_doc_embeddings_from_existing_vectorizer(doc_texts, vectorizer)
test_embeddings = test_embeddings.astype(np.float32)   # [num_test, d]
# load test pids
pids = doc_ids   # Ïù¥ÎØ∏ Î¨∏ÏûêÏó¥ id
if "L" in globals():
    if not isinstance(L, np.ndarray):
        # Ïòà: LÏù¥ torch.TensorÏù∏ Í≤ΩÏö∞
        L = L.detach().cpu().numpy().astype(np.float32)
else:
    raise ValueError("ÎùºÎ≤® ÏûÑÎ≤†Îî© LÏù¥ Î©îÎ™®Î¶¨Ïóê ÏóÜÏñ¥! GAT ÎÅùÎÇú Îí§Ïùò ÏûÑÎ≤†Îî©ÏùÑ LÎ°ú Îë¨Ïïº Ìï¥.")

# 5) ÎùºÎ≤® idÎäî 0..N-1Î°ú ÏÉùÏÑ± (ÎÑ§Í∞Ä ÎßêÌïú ÎåÄÎ°ú adjacencyÎûë ÏàúÏÑúÍ∞Ä Ïù¥ÎØ∏ ÎßûÎã§Í≥† ÌñàÏúºÎãàÍπå)
lab_ids = np.arange(L.shape[0], dtype=np.int64)

# 6) adjacencyÎèÑ Î©îÎ™®Î¶¨Ïóê ÏûàÎäî Í±∏ Í∑∏ÎåÄÎ°ú Ïì¥Îã§
#    Ïó¨Í∏∞ÏÑú AÎäî 531x531 Í∞ôÏùÄ numpy arrayÎùºÍ≥† Í∞ÄÏ†ï
assert B.shape == (L.shape[0], L.shape[0]), "Adjacency/label size mismatch"

# 7) children Î¶¨Ïä§Ìä∏ ÎØ∏Î¶¨ ÎßåÎì§Ïñ¥ÎëêÍ∏∞
children = [np.flatnonzero(B[i]) for i in range(B.shape[0])]



In [15]:
model.eval()

def ancestors_of(node, adj):
    # adj[parent, child] = 1 Í∞ÄÏ†ï
    parents = np.flatnonzero(adj[:, node])  # (N,)
    return parents.tolist()

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

IN_DIM = test_embeddings.shape[1]
missing = 0  # ÏßÄÍ∏àÏùÄ Ïì∏ Ïùº ÏóÜÏßÄÎßå ÏõêÎûò ÏΩîÎìúÎûë ÌòïÌÉú ÎßûÏ∂∞Îë†

with open(OUT_PATH, "w", newline="", encoding="utf-8") as f:
    w = csv.writer(f)
    w.writerow(["id", "label"])

    buf_x, buf_pid = [], []

    def flush():
        if not buf_x:
            return
        xb = torch.from_numpy(np.stack(buf_x, axis=0).astype(np.float32)).to(device)
        with torch.inference_mode():
            prob = torch.sigmoid(model(xb)).detach().cpu().numpy()
        prob = np.nan_to_num(prob, nan=-1.0, posinf=1.0, neginf=0.0)

        for pid, p in zip(buf_pid, prob):
            order = np.argsort(-p)

            # 1) Í∏∞Î≥∏ ÌõÑÎ≥¥ ÎΩëÍ∏∞
            thr_keep = [i for i in order if p[i] >= 0.5][:MAX_LABS]
            if len(thr_keep) >= MIN_LABS:
                keep = thr_keep[:MAX_LABS]
            else:
                keep = order[:max(MIN_LABS, len(thr_keep))]
                if len(keep) < MIN_LABS:
                    keep = order[:MIN_LABS]

            # 2) Î∂ÄÎ™® ÌõÑÎ≥¥
            parent_cands = []
            for c in keep:
                pars = ancestors_of(c, B)
                for pa in pars:
                    if pa not in keep and pa not in parent_cands:
                        parent_cands.append(pa)

            parent_cands.sort(key=lambda idx: p[idx], reverse=True)

            # 3) ÎÇ®Îäî Ïä¨Î°Ø Î∂ÄÎ™®Î°ú Ï±ÑÏö∞Í∏∞
            final_idxs = list(keep)
            for pa in parent_cands:
                if len(final_idxs) >= MAX_LABS:
                    break
                final_idxs.append(pa)

            # 4) Í∑∏ÎûòÎèÑ Î™®ÏûêÎùºÎ©¥ ÌôïÎ•†Ïàú
            if len(final_idxs) < MIN_LABS:
                for idx in order:
                    if idx not in final_idxs:
                        final_idxs.append(idx)
                    if len(final_idxs) >= MIN_LABS:
                        break

            labels = sorted(int(lab_ids[i]) for i in final_idxs)
            w.writerow([pid, ",".join(map(str, labels))])

        buf_x.clear()
        buf_pid.clear()

    # Ïó¨Í∏∞ÏÑú Î∞îÎ°ú pidsÏôÄ test_embeddingsÎ•º Í∞ôÏù¥ ÏàúÌöå
    for pid, emb in zip(pids, test_embeddings):
        x = emb
        if x.dtype != np.float32:
            x = x.astype(np.float32, copy=False)
        buf_x.append(x)
        buf_pid.append(pid)
        if len(buf_x) >= BATCH:
            flush()
    flush()

print(f"Saved: {OUT_PATH} | samples={len(pids)} | min-max labels per sample={MIN_LABS}-{MAX_LABS} | missing_pids={missing}")


Saved: submission_bda.csv | samples=19658 | min-max labels per sample=2-3 | missing_pids=0


In [89]:
"""
# ------------------------
# Dummy baseline for Kaggle submission
# Generates random multi-label predictions
# ------------------------
import os
import csv
import random
from tqdm import tqdm

# --- Paths ---
TEST_DIR = "Amazon_products/test"  # modify if needed
TEST_CORPUS_PATH = os.path.join(TEST_DIR, "test_corpus.txt")  # product_id \t text
SUBMISSION_PATH = "submission.csv"  # output file

# --- Constants ---
NUM_CLASSES = 531  # total number of classes (0‚Äì530)
MIN_LABELS = 1     # minimum number of labels per sample
MAX_LABELS = 3     # maximum number of labels per sample

# --- Load test corpus ---
def load_corpus(path):
    """Load test corpus into {pid: text} dictionary."""
"""
    pid2text = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split("\t", 1)
            if len(parts) == 2:
                pid, text = parts
                pid2text[pid] = text
    return pid2text

pid2text_test = load_corpus(TEST_CORPUS_PATH)
pid_list_test = list(pid2text_test.keys())

# --- Generate random predictions ---
all_pids, all_labels = [], []
for pid in tqdm(pid_list_test, desc="Generating dummy predictions"):
    n_labels = random.randint(MIN_LABELS, MAX_LABELS)
    labels = random.sample(range(NUM_CLASSES), n_labels)
    labels = sorted(labels)
    all_pids.append(pid)
    all_labels.append(labels)

# --- Save submission file ---
with open(SUBMISSION_PATH, "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(["id", "label"])
    for pid, labels in zip(all_pids, all_labels):
        writer.writerow([pid, ",".join(map(str, labels))])

print(f"Dummy submission file saved to: {SUBMISSION_PATH}")
print(f"Total samples: {len(all_pids)}, Classes per sample: {MIN_LABELS}-{MAX_LABELS}")"""

Generating dummy predictions: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 19658/19658 [00:00<00:00, 419908.88it/s]

Dummy submission file saved to: submission.csv
Total samples: 19658, Classes per sample: 1-3



