# Gating Evaluation
This notebook evaluates the accuracy of the gating network used by FinMoE.

In [None]:
import torch
from pathlib import Path
from huggingface_hub import constants as hub_c
from transformers import AutoTokenizer
from tqdm import tqdm

from evals import load_eval_dataset
from FinMoE import FinMoE
from utils import get_dataset_args

# Notebook-wide constants: RNG seed and the Hugging Face id of the tokenizer to load.
seed = 42
model_id = "meta-llama/Llama-3.2-1B"

# The notebook places everything on the GPU, so stop right away if CUDA is missing.
assert torch.cuda.is_available(), "CUDA not installed"
device = torch.device("cuda")
torch.manual_seed(seed)

# Load the tokenizer; when it defines no padding token, reuse the EOS token.
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Dataset arguments are derived from the tokenizer and the local Hugging Face
# hub cache directory (the exact contents are decided by `get_dataset_args`).
hub_cache_dir = Path(hub_c.HF_HUB_CACHE)
args = get_dataset_args(tokenizer, hub_cache_dir)

In [2]:
# Location of the trained checkpoint to evaluate.
ckpt_path = Path("D:/models/FinMoE-fast-gating", "checkpoint-3590")

# Load the checkpoint, move the model to the selected device and switch it to
# evaluation mode.
finMoE_model = FinMoE.load_pretrained(ckpt_path)
finMoE_model = finMoE_model.to(device).eval()

In [None]:
# Measure, per dataset, how often the gate's highest-scoring expert is the one
# expected for that dataset. The position of a dataset in `dataset_ids` is used
# as its expected expert index.
dataset_ids = ["FPB", "Headline", "Topics"]

# Pure evaluation: disable autograd so the gate's forward passes do not record
# a computation graph.
with torch.no_grad():
    for expert_idx, dataset_id in enumerate(dataset_ids):
        testset = load_eval_dataset(tokenizer, dataset_id, args)

        correct = 0
        total = 0
        progbar = tqdm(testset)
        for example in progbar:
            # Build the inputs directly on the target device.
            input_ids = torch.tensor(example["input_ids"], device=device)
            attn_mask = torch.tensor(example["attention_mask"], device=device)
            # NOTE(review): if `gate` is an nn.Module, calling it directly
            # (`finMoE_model.gate(...)`) is preferred over `.forward` so that
            # registered hooks run - confirm before changing.
            gate_scores = finMoE_model.gate.forward(input_ids, attn_mask)

            # `.item()` keeps the running counters as plain Python ints rather
            # than silently turning `correct` into a device tensor.
            predicted_expert = torch.argmax(gate_scores, dim=-1)
            correct += (predicted_expert == expert_idx).sum().item()
            # Assumes dim 0 of `gate_scores` counts the scored examples - TODO confirm.
            total += gate_scores.size(0)
            if total:
                progbar.set_description(f"{correct / total * 100:.2f}")

        # An empty test set would otherwise raise ZeroDivisionError below.
        if total == 0:
            print(f"Gate Accuracy for expert {expert_idx} {dataset_id}: n/a (no examples evaluated)")
            continue

        print(f"Gate Accuracy for expert {expert_idx} {dataset_id}: {correct / total * 100:.2f}")