In [1]:
# ============================================================
# 🧠 Train Gating Network for Sparse MoE (v3)
#  - Encoder: google/embeddinggemma-300m (SentenceTransformer)
#  - Gating: 2-layer MLP on top of the sentence embedding
#  - Goal: pick the appropriate expert for each question
# ============================================================

import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from sentence_transformers import SentenceTransformer

# ------------------------------------------------------------
# ⚙️ CONFIG
# ------------------------------------------------------------
# Input data and encoder checkpoint.
DATA_FILE = "gating_train.csv"              # merged dataset covering all experts
MODEL_NAME = "google/embeddinggemma-300m"   # same encoder as the Dual Expert

# Where the trained gating weights and metadata are written.
SAVE_DIR = "models_gating_v3"

# Optimisation hyper-parameters.
BATCH_SIZE = 32
EPOCHS = 10
LR = 2e-5

# Seed shared by every random number generator seeded below.
SEED = 42

# Run on the GPU whenever PyTorch reports one, otherwise on the CPU.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"

# Create the output directory up front (no error if it already exists).
os.makedirs(SAVE_DIR, exist_ok=True)

# Seed PyTorch (CPU, then every CUDA device) and Python's `random` module.
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
random.seed(SEED)

# ------------------------------------------------------------
# 🧩 Dataset
#   - question: str
#   - expert_label: expert name (Advice / Treatment / Disease / Topic / ...)
# ------------------------------------------------------------
class GatingDataset(Dataset):
    """Map-style dataset yielding ``(question_text, target_vector)`` pairs.

    The ``question`` and ``expert_label`` columns of ``df`` are cast to ``str``
    and copied into plain Python lists when the dataset is constructed.

    Args:
        df: table providing the ``question`` and ``expert_label`` columns.
        label2id: mapping from expert name to its index in the target vector.
    """

    def __init__(self, df, label2id):
        question_col = df["question"].astype(str)
        expert_col = df["expert_label"].astype(str)
        self.texts = question_col.tolist()
        self.labels = expert_col.tolist()
        self.label2id = label2id

    def __len__(self):
        """Return the number of questions held by the dataset."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the ``idx``-th question and its float32 target vector.

        The target has one slot per entry of ``label2id``; the slot of the
        row's expert is set to 1.0. A label missing from ``label2id`` leaves
        the vector all zeros.
        """
        question = self.texts[idx]
        expert = self.labels[idx]

        target = torch.zeros(len(self.label2id), dtype=torch.float32)
        if expert in self.label2id:
            slot = self.label2id[expert]
            target[slot] = 1.0
        return question, target


# ------------------------------------------------------------
# 🔧 Helper encode_texts (same as dual expert v3)
# ------------------------------------------------------------
def encode_texts(encoder: SentenceTransformer, texts, batch_size=64):
    """Embed ``texts`` with ``encoder`` without tracking gradients.

    Args:
        encoder: model whose ``encode`` method produces the embeddings.
        texts: the sentences to embed.
        batch_size: batch size forwarded to ``encoder.encode``.

    Returns:
        The tensor returned by ``encoder.encode`` (requested on ``DEVICE``);
        expected shape is (N, D) — one row per input text.
    """
    encode_options = {
        "batch_size": batch_size,
        "convert_to_tensor": True,
        "device": DEVICE,
        "show_progress_bar": False,
    }
    with torch.no_grad():
        return encoder.encode(texts, **encode_options)


# ------------------------------------------------------------
# 🧠 Gating MLP
#   - Input: sentence embedding (D)
#   - Output: logits (num_experts) — uses BCEWithLogits for multi-label flexibility
# ------------------------------------------------------------
class GatingMLP(nn.Module):
    """Two-layer perceptron mapping an embedding to per-expert logits.

    Args:
        input_dim: size of the incoming embedding.
        num_classes: number of experts (size of the output).
        hidden_dim: width of the single hidden layer.
    """

    def __init__(self, input_dim, num_classes, hidden_dim=512):
        super().__init__()
        # The attribute names double as state_dict keys, so they must stay
        # `fc1` / `fc2` for saved checkpoints to keep loading.
        self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=num_classes)

    def forward(self, x):
        """Return raw logits of shape (B, num_classes); no sigmoid/softmax applied."""
        hidden = self.fc1(x)
        activated = F.relu(hidden)
        logits = self.fc2(activated)
        return logits


# ------------------------------------------------------------
# 🚀 Train Gating
# ------------------------------------------------------------
def train_gating():
    """Train the gating MLP on frozen sentence embeddings and save it.

    Reads ``DATA_FILE``, builds the expert <-> id mappings from the
    ``expert_label`` column, encodes each batch of questions with the frozen
    ``MODEL_NAME`` encoder, and fits a ``GatingMLP`` with
    ``BCEWithLogitsLoss`` for ``EPOCHS`` epochs. The weights are written to
    ``SAVE_DIR/gating_mlp.pt`` and the label mappings plus model
    hyper-parameters to ``SAVE_DIR/gating_meta.pt``.

    The loss and top-1 accuracy shown are computed on the training batches
    themselves and are averaged per sample, so a short final batch is not
    over-weighted.

    Raises:
        ValueError: if no row has both ``question`` and ``expert_label``.
    """
    from contextlib import nullcontext

    print("\n=== üß© Training Gating Network for Sparse MoE (v3) ===")

    # 1) Load data
    df = pd.read_csv(DATA_FILE)
    df = df.dropna(subset=["question", "expert_label"])
    if df.empty:
        # Fail early with a clear message instead of dividing by zero later.
        raise ValueError(
            f"No rows with both 'question' and 'expert_label' in {DATA_FILE}"
        )

    # Map expert_label -> id (sorted so the ids are stable across runs)
    experts = sorted(df["expert_label"].astype(str).unique())
    label2id = {label: i for i, label in enumerate(experts)}
    id2label = {i: label for label, i in label2id.items()}

    print(f"‚úÖ Loaded {len(df)} samples")
    print(f"‚úÖ Num experts: {len(experts)}")
    print("   Experts:", experts)

    dataset = GatingDataset(df, label2id)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    # 2) Base encoder (frozen)
    base_encoder = SentenceTransformer(MODEL_NAME)
    base_encoder.to(DEVICE)
    base_encoder.eval()
    for p in base_encoder.parameters():
        p.requires_grad = False

    # Get the embedding dimension. The accessor may fail or report None, so
    # fall back to measuring a probe embedding in either case.
    try:
        input_dim = base_encoder.get_sentence_embedding_dimension()
    except Exception:
        input_dim = None
    if input_dim is None:
        test_emb = encode_texts(base_encoder, ["test"])
        input_dim = test_emb.size(1)

    # 3) Gating model
    num_classes = len(label2id)
    model = GatingMLP(input_dim=input_dim, num_classes=num_classes).to(DEVICE)

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    loss_fn = nn.BCEWithLogitsLoss()

    use_amp = DEVICE == "cuda"
    scaler = torch.amp.GradScaler("cuda") if use_amp else None

    # The forward context depends only on `use_amp`, so choose its factory
    # once here rather than on every training step.
    if use_amp:
        def forward_ctx():
            return torch.amp.autocast("cuda")
    else:
        forward_ctx = nullcontext

    # 4) Training loop
    # NOTE(review): the encoder is frozen, yet every batch is re-encoded in
    # every epoch; caching the embeddings once would avoid that repeated work.
    for epoch in range(EPOCHS):
        model.train()
        loss_sum = 0.0   # sum of per-sample losses seen this epoch
        n_correct = 0    # number of top-1 hits seen this epoch
        n_seen = 0       # number of samples seen this epoch

        pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{EPOCHS}")

        for texts, y in pbar:
            y = y.to(DEVICE)

            with forward_ctx():
                # Encode texts -> embeddings
                emb = encode_texts(base_encoder, list(texts))  # (B, D)

                # Forward gating
                logits = model(emb)  # (B, num_classes)
                loss = loss_fn(logits, y)

                # metric: top-1 accuracy (argmax of logits vs. argmax of target)
                with torch.no_grad():
                    preds = torch.argmax(logits, dim=1)
                    gold = torch.argmax(y, dim=1)
                    correct = int((preds == gold).sum().item())

            optimizer.zero_grad()
            if use_amp:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()

            # Weight by batch size so the running values are true per-sample
            # averages even when the last batch is smaller than BATCH_SIZE.
            batch_len = y.size(0)
            n_seen += batch_len
            loss_sum += loss.item() * batch_len
            n_correct += correct

            pbar.set_postfix(
                loss=f"{loss_sum/n_seen:.4f}",
                acc=f"{n_correct/n_seen:.4f}"
            )

        print(
            f"‚úÖ Epoch {epoch+1} | "
            f"Loss={loss_sum/n_seen:.4f} | "
            f"Top1 Acc={n_correct/n_seen:.4f}"
        )

    # 5) Save model + label mapping
    torch.save(model.state_dict(), os.path.join(SAVE_DIR, "gating_mlp.pt"))
    torch.save(
        {
            "label2id": label2id,
            "id2label": id2label,
            "input_dim": input_dim,
            "num_classes": num_classes,
            "model_name": MODEL_NAME,
        },
        os.path.join(SAVE_DIR, "gating_meta.pt"),
    )

    print(f"\nüíæ Saved gating model to: {SAVE_DIR}")


# ------------------------------------------------------------
# Start training only when this module is the entry point, not when imported.
if __name__ == "__main__":
    train_gating()


  from .autonotebook import tqdm as notebook_tqdm



=== üß© Training Gating Network for Sparse MoE (v3) ===
‚úÖ Loaded 227322 samples
‚úÖ Num experts: 12
   Experts: ['Advice', 'Application', 'Cause', 'Complication', 'Definition', 'Detail', 'Population', 'Prevention', 'RiskFactor', 'SubDisease', 'Symptom', 'Treatment']


Epoch 1/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7104/7104 [03:24<00:00, 34.78it/s, acc=0.7500, loss=0.2302]


‚úÖ Epoch 1 | Loss=0.2302 | Top1 Acc=0.7500


Epoch 2/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7104/7104 [03:22<00:00, 35.10it/s, acc=0.9575, loss=0.0612]


‚úÖ Epoch 2 | Loss=0.0612 | Top1 Acc=0.9575


Epoch 3/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7104/7104 [03:23<00:00, 34.99it/s, acc=0.9868, loss=0.0213]


‚úÖ Epoch 3 | Loss=0.0213 | Top1 Acc=0.9868


Epoch 4/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7104/7104 [03:25<00:00, 34.58it/s, acc=0.9963, loss=0.0091]


‚úÖ Epoch 4 | Loss=0.0091 | Top1 Acc=0.9963


Epoch 5/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7104/7104 [03:18<00:00, 35.74it/s, acc=0.9987, loss=0.0045]


‚úÖ Epoch 5 | Loss=0.0045 | Top1 Acc=0.9987


Epoch 6/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7104/7104 [03:15<00:00, 36.38it/s, acc=0.9993, loss=0.0026]


‚úÖ Epoch 6 | Loss=0.0026 | Top1 Acc=0.9993


Epoch 7/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7104/7104 [03:23<00:00, 34.83it/s, acc=0.9995, loss=0.0016]


‚úÖ Epoch 7 | Loss=0.0016 | Top1 Acc=0.9995


Epoch 8/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7104/7104 [03:25<00:00, 34.62it/s, acc=0.9997, loss=0.0011]


‚úÖ Epoch 8 | Loss=0.0011 | Top1 Acc=0.9997


Epoch 9/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7104/7104 [03:25<00:00, 34.54it/s, acc=0.9998, loss=0.0008]


‚úÖ Epoch 9 | Loss=0.0008 | Top1 Acc=0.9998


Epoch 10/10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7104/7104 [03:23<00:00, 34.95it/s, acc=0.9998, loss=0.0006]


‚úÖ Epoch 10 | Loss=0.0006 | Top1 Acc=0.9998

üíæ Saved gating model to: models_gating_v3


In [14]:
import os
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer


# Artifacts written by train_gating(); SAVE_DIR must already be defined
# (it is not assigned in this cell).
META_PATH = os.path.join(SAVE_DIR, "gating_meta.pt")
STATE_PATH = os.path.join(SAVE_DIR, "gating_mlp.pt")
# Use the GPU when PyTorch can see one, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
class GatingMLP(nn.Module):
    """Gating head: embedding -> hidden layer with ReLU -> per-expert logits.

    Args:
        input_dim: size of the incoming embedding.
        num_classes: number of experts (size of the output).
        hidden_dim: width of the single hidden layer.
    """

    def __init__(self, input_dim, num_classes, hidden_dim=512):
        super().__init__()
        # `fc1` / `fc2` are the keys used in the saved state_dict; renaming
        # them would make the checkpoint fail to load.
        self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=num_classes)

    def forward(self, x):
        """Return raw logits of shape (B, num_classes)."""
        hidden = F.relu(self.fc1(x))
        logits = self.fc2(hidden)
        return logits


def encode_texts(encoder: SentenceTransformer, texts, batch_size=64):
    """Embed ``texts`` with ``encoder`` while gradient tracking is disabled.

    Args:
        encoder: model whose ``encode`` method produces the embeddings.
        texts: the sentences to embed.
        batch_size: batch size forwarded to ``encoder.encode``.

    Returns:
        The tensor returned by ``encoder.encode`` (requested on ``DEVICE``);
        expected shape is (N, D) — one row per input text.
    """
    encode_options = {
        "batch_size": batch_size,
        "convert_to_tensor": True,
        "device": DEVICE,
        "show_progress_bar": False,
    }
    with torch.no_grad():
        embeddings = encoder.encode(texts, **encode_options)
    return embeddings



# Restore the label mappings and model hyper-parameters saved by training.
meta = torch.load(META_PATH, map_location=DEVICE)
label2id = meta["label2id"]
id2label = meta["id2label"]
input_dim = meta["input_dim"]
num_classes = meta["num_classes"]
model_name = meta["model_name"]

# Rebuild the encoder named in the metadata and freeze it for inference.
encoder = SentenceTransformer(model_name).to(DEVICE)
encoder.eval()
for p in encoder.parameters():
    p.requires_grad = False

# Recreate the gating head with the saved dimensions and load its weights.
model = GatingMLP(input_dim=input_dim, num_classes=num_classes).to(DEVICE)
model.load_state_dict(torch.load(STATE_PATH, map_location=DEVICE))
model.eval()

QUESTION = "hiv l√† g√¨ v√† ƒë∆∞·ª£c ƒëi·ªÅu tr·ªã nh∆∞ th·∫ø n√†o"

with torch.no_grad():
    emb = encode_texts(encoder, [QUESTION])  # (1, D)
    logits = model(emb)                      # (1, num_classes)
    # Independent sigmoid per expert: the scores need not sum to 1.
    probs = torch.sigmoid(logits).squeeze(0) # (num_classes,)

    # torch.topk raises when k exceeds the number of entries, so never ask
    # for more experts than the model was trained with.
    topk = torch.topk(probs, k=min(5, probs.numel()))
    print(f"\nüß™ Question: {QUESTION}")
    print("Top experts:")
    for score, idx in zip(topk.values, topk.indices):
        name = id2label[int(idx)]
        print(f"  {name:15s} ‚Üí {float(score):.4f}")



üß™ Question: hiv l√† g√¨ v√† ƒë∆∞·ª£c ƒëi·ªÅu tr·ªã nh∆∞ th·∫ø n√†o
Top experts:
  Treatment       ‚Üí 0.9490
  Definition      ‚Üí 0.6008
  Detail          ‚Üí 0.0078
  Application     ‚Üí 0.0000
  Advice          ‚Üí 0.0000


In [17]:
import os
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
import torch.nn as nn

# Artifacts written by train_gating(); SAVE_DIR must already be defined
# (it is not assigned in this cell).
META_PATH = os.path.join(SAVE_DIR, "gating_meta.pt")
STATE_PATH = os.path.join(SAVE_DIR, "gating_mlp.pt")

# Use the GPU when PyTorch can see one, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# -------------------------
# Gating MLP
# -------------------------
class GatingMLP(nn.Module):
    """Small MLP producing one logit per expert from a sentence embedding.

    Args:
        input_dim: size of the incoming embedding.
        num_classes: number of experts (size of the output).
        hidden_dim: width of the single hidden layer.
    """

    def __init__(self, input_dim, num_classes, hidden_dim=512):
        super().__init__()
        # Layer names are part of the checkpoint format (state_dict keys).
        self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=num_classes)

    def forward(self, x):
        """Apply fc1, ReLU, then fc2 and return the raw logits."""
        pre_activation = self.fc1(x)
        return self.fc2(F.relu(pre_activation))


# -------------------------
# Load metadata
# -------------------------
# Read back the dictionary stored at META_PATH and unpack its entries.
meta = torch.load(META_PATH, map_location=DEVICE)
label2id, id2label = meta["label2id"], meta["id2label"]
input_dim, num_classes = meta["input_dim"], meta["num_classes"]
model_name = meta["model_name"]

# -------------------------
# Load encoder (FAST MODE)
# -------------------------
# Build the encoder directly on DEVICE, switch to eval mode and freeze it.
encoder = SentenceTransformer(model_name, device=DEVICE)
encoder.eval()
for p in encoder.parameters():
    p.requires_grad = False


# FAST INFERENCE: encode a single sentence with .encode() while avoiding sync
def fast_encode(text):
    """Embed one string with the module-level ``encoder``.

    The text is wrapped in a one-element list and encoded with batch size 1
    under ``torch.inference_mode()``.

    Returns:
        The tensor returned by ``encoder.encode`` (requested on ``DEVICE``);
        presumably shape (1, D) — confirm against the encoder.
    """
    encode_options = {
        "batch_size": 1,
        "convert_to_tensor": True,
        "device": DEVICE,
        "show_progress_bar": False,
    }
    with torch.inference_mode():
        return encoder.encode([text], **encode_options)


# -------------------------
# Load gating model
# -------------------------
# Recreate the gating head with the saved dimensions and load its weights.
model = GatingMLP(input_dim, num_classes).to(DEVICE)
model.load_state_dict(torch.load(STATE_PATH, map_location=DEVICE))
model.eval()

# -------------------------
# Test
# -------------------------
QUESTION = "hiv l√† g√¨"

with torch.inference_mode():
    emb = fast_encode(QUESTION)
    logits = model(emb)
    # Independent sigmoid per expert: the scores need not sum to 1.
    probs = torch.sigmoid(logits).squeeze(0)

    # torch.topk raises when k exceeds the number of entries, so never ask
    # for more experts than the model was trained with.
    topk = torch.topk(probs, k=min(5, probs.numel()))

    print(f"\nüß™ Question: {QUESTION}")
    print("Top experts:")
    for score, idx in zip(topk.values, topk.indices):
        print(f"{id2label[int(idx)]:20s} ‚Üí {float(score):.4f}")



üß™ Question: hiv l√† g√¨
Top experts:
Definition           ‚Üí 0.9999
Detail               ‚Üí 0.0004
SubDisease           ‚Üí 0.0000
Complication         ‚Üí 0.0000
Application          ‚Üí 0.0000


cuda:0
