# Gating Evaluation
This notebook evaluates the accuracy of the gating network used by FinMoE

In [1]:
import torch
from pathlib import Path
from transformers import AutoTokenizer
from tqdm import tqdm

from evals import load_eval_dataset
from FinMoE import FinMoE
from utils import get_dataset_args

assert torch.cuda.is_available(), "CUDA not installed"
device = torch.device("cuda")

seed = 42
torch.manual_seed(seed)

MAX_LENGTH = 512

model_id = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

hub_basepath = Path(r"C:\Users\samba\.cache\huggingface\hub")
args = get_dataset_args(tokenizer, hub_basepath)

In [3]:
ckpt_path = Path(rf"D:/models/FinMoE-v2") / "checkpoint-14360"
finMoE_model = FinMoE.load_pretrained(ckpt_path).to(device).eval()

In [None]:
dataset_ids = ["FPB", "Headline", "Topics"]

for expert_idx, dataset_id in enumerate(dataset_ids):
    testset = load_eval_dataset(tokenizer, dataset_id, args)

    correct = 0
    total = 0
    progbar = tqdm(testset)
    for example in progbar:
        input_ids = torch.tensor(example["input_ids"]).to(device)
        attn_mask = torch.tensor(example["attention_mask"]).to(device)
        gate_scores, top1_indices = finMoE_model.gate.forward(input_ids, attn_mask)

        correct += (torch.argmax(gate_scores, dim=-1) == expert_idx).sum()
        total += gate_scores.size(0)
        progbar.set_description(f"{correct/total*100:.2f}")
    
    print(f"Gate Accuracy for expert {expert_idx} {dataset_id}: {correct / total * 100:.2f}")

Loading Topics from path C:\Users\samba\.cache\huggingface\hub\datasets--Sujet--TopicClassification


Map:   0%|          | 0/850 [00:00<?, ? examples/s]

Map:   0%|          | 0/850 [00:00<?, ? examples/s]

0.00: 100%|██████████| 850/850 [00:45<00:00, 18.70it/s]

Gate Accuracy for expert 2 Topics: 0.00





### Gating network weights

In [9]:
finMoE_model.gate.w_gate.weight.sum(dim=1)

tensor([ 0.7107, -0.9033,  0.2338], device='cuda:0', grad_fn=<SumBackward1>)

In [8]:
finMoE_model.gate.w_gate.bias

Parameter containing:
tensor([-0.0235,  0.0092, -0.0074], device='cuda:0', requires_grad=True)