# Evaluate Router Performance (CombineQA)

In [32]:
import torch
import sys
import pandas as pd
import ast
import numpy as np

sys.path.append('../training')
from router import Router

In [3]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

In [4]:
# Load trained router
def load_router(device, model_path="router.pth"):
    router = Router(input_dim=1024, output_dim=4).to(device)
    router.load_state_dict(torch.load(model_path, map_location=device))
    router.eval()
    return router

In [5]:
# Load in MMLU question embeddings
mmlu = pd.read_csv('../training/mmlu_test_metadata.csv', usecols=['subject','question', 'embedding', 'Qwen_correct', 'MathQwen_correct', 'CodeQwen_correct', 'label'])
mmlu.head()

Unnamed: 0,question,subject,embedding,Qwen_correct,MathQwen_correct,CodeQwen_correct,label
0,Find the degree for the given field extension ...,abstract_algebra,"[-0.00849960371851921, 0.02478231117129326, -0...",0,1,0,"[0, 1, 0]"
1,"Let p = (1, 2, 5, 4)(2, 3) in S_5 . Find the i...",abstract_algebra,"[0.02112336829304695, 0.023575058206915855, -0...",0,0,0,"[0, 0, 0]"
2,Find all zeros in the indicated finite field o...,abstract_algebra,"[0.029676876962184906, 0.033119432628154755, -...",0,0,0,"[0, 0, 0]"
3,Statement 1 | A factor group of a non-Abelian ...,abstract_algebra,"[0.009602331556379795, 0.02028682269155979, -0...",0,1,0,"[0, 1, 0]"
4,Find the product of the given polynomials in t...,abstract_algebra,"[0.04502462223172188, 0.03369951620697975, -0....",1,0,0,"[1, 0, 0]"


In [6]:
all_embeds = mmlu['embedding'].apply(lambda x: torch.tensor(ast.literal_eval(x)).to(DEVICE))
all_embeds = torch.stack(list((all_embeds)))
# all_labels = mmlu['label'].apply(lambda x: ast.literal_eval(x))
all_subjects = mmlu['subject']

# Change [0,0,0] to [0,0,0,1] to simulate choosing a fallback LLM
def expand_labels(label):
    if label == '[0, 0, 0]':
        return [0,0,0,1]
    else:
        return ast.literal_eval(label) + [0]
    
all_labels = mmlu['label'].apply(expand_labels)

In [48]:
ckpt = '/Users/swtsai/Documents/Qwen-CoE/training/checkpoints_combine/epochs=1500_patience=50_batch=16_lr=0.0001/epoch=1474_loss=0.925_tacc=0.415_vacc=0.307.pth'
model = load_router(DEVICE, model_path=ckpt)

with torch.no_grad():
    outputs = model(all_embeds)
    predicted_indices = (outputs > 0.5).float()
    
correct = 0
all_correct = []
fallback_count = 0
zero_count = 0
all_pred = []
for i in range(len(all_labels)):
    pred = np.array(predicted_indices[i].cpu())
    label = np.array(all_labels[i])
    match = np.sum(pred * label)

    all_pred.append(tuple(pred.tolist()))

    if match == 1:
        correct += 1
        all_correct.append(1)
    else:
        all_correct.append(0)

total_examples = len(all_labels)
accuracy = correct / total_examples
print(f'Overall Accuracy: {accuracy:.3f}')

results_dict = {'Subject': all_subjects, 'Accuracy': all_correct, 'Label': all_labels, 'Prediction': all_pred}
results_df = pd.DataFrame(results_dict)
results_df.head()

  pred = np.array(predicted_indices[i].cpu())


Overall Accuracy: 0.124


Unnamed: 0,Subject,Accuracy,Label,Prediction
0,abstract_algebra,0,"[0, 1, 0, 0]","(0.0, 0.0, 0.0, 0.0)"
1,abstract_algebra,1,"[0, 0, 0, 1]","(0.0, 0.0, 0.0, 1.0)"
2,abstract_algebra,0,"[0, 0, 0, 1]","(1.0, 0.0, 1.0, 0.0)"
3,abstract_algebra,0,"[0, 1, 0, 0]","(0.0, 0.0, 0.0, 0.0)"
4,abstract_algebra,0,"[1, 0, 0, 0]","(0.0, 0.0, 0.0, 1.0)"


In [52]:
group = results_df.groupby('Prediction')
group.size()

Prediction
(0.0, 0.0, 0.0, 0.0)    9631
(0.0, 0.0, 0.0, 1.0)    2742
(0.0, 0.0, 1.0, 0.0)     297
(0.0, 1.0, 0.0, 0.0)      92
(0.0, 1.0, 1.0, 0.0)       9
(1.0, 0.0, 0.0, 0.0)      14
(1.0, 0.0, 1.0, 0.0)      69
(1.0, 1.0, 0.0, 0.0)     413
(1.0, 1.0, 1.0, 0.0)     775
dtype: int64

In [53]:
results_df.to_csv('../training/checkpoints_combine/eval_results.csv')