In [None]:
import os
import random
from pathlib import Path

import torch
from huggingface_hub import constants as hub_c
from tqdm import tqdm
from transformers import AutoTokenizer

from evals import load_eval_dataset, get_tensors
from FinMoE import FinMoE
from utils import get_dataset_args

assert torch.cuda.is_available(), "CUDA not available"
device = torch.device("cuda")

# Seed every RNG this notebook uses: torch (model side) and Python's `random`,
# which picks the examples displayed at the end. Without the second seed the
# displayed samples change on every Restart & Run All.
seed = 42
torch.manual_seed(seed)
random.seed(seed)

# Tokenizer of the base model; presumably the model FinMoE was built on — TODO confirm.
model_id = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    # No pad token defined: reuse EOS so padded encoding works.
    tokenizer.pad_token = tokenizer.eos_token

# Per-dataset configuration (prompt templates, answer-token options, ...),
# resolved against the local Hugging Face hub cache.
args = get_dataset_args(tokenizer, Path(hub_c.HF_HUB_CACHE))

In [2]:
# Checkpoint to evaluate. Set the FINMOE_CKPT environment variable to point at a
# different checkpoint instead of editing the machine-specific default below.
ckpt_path = Path(os.environ.get("FINMOE_CKPT", r"D:/models/FinMoE-final-top3-fast/checkpoint-3590"))  # 3590  e256
finMoE_model = FinMoE.load_pretrained(ckpt_path).to(device).eval()

In [None]:
def predict(sentence: str, dataset_id: str):
    """Classify a single raw sentence with the full FinMoE model.

    The sentence is formatted into the dataset's prompt template and run through
    ``finMoE_model``. The logits at the last prompt position are then restricted
    to the dataset's candidate answer tokens (``args.token_opts[dataset_id]``).
    The best-scoring candidate is decoded and returned, whitespace-stripped.

    Args:
        sentence: Text to classify.
        dataset_id: Key into ``args.prompt_templates`` and ``args.token_opts``.

    Returns:
        The decoded answer token as a string.
    """
    prompt = args.prompt_templates[dataset_id].format(sentence)
    encoded = tokenizer(prompt, truncation=False, return_tensors="pt")
    token_opts = args.token_opts[dataset_id]

    # Inference only: skip autograd bookkeeping so activations are not retained.
    with torch.no_grad():
        output = finMoE_model.forward(encoded["input_ids"].to(device),
                                      attention_mask=encoded["attention_mask"].to(device))

    # A single unpadded prompt, so position -1 is the last real token.
    cand_logits = output.logits[0, -1, token_opts].cpu()
    best = torch.argmax(cand_logits, dim=-1).item()
    return tokenizer.decode([token_opts[best]]).strip()

def display(sample, topic_names=None):
    """Print an evaluated sample's decoded prompt with its prediction and label.

    Note: this name shadows IPython's built-in ``display`` in this notebook.

    Args:
        sample: Evaluated sample dict. Reads the keys "input_ids", "prediction"
            and "label"; the evaluation loop sets the last two.
        topic_names: Optional sequence of class names (e.g. ``args.topics``).
            When given, "prediction" and "label" are treated as stringified
            indices into it, and the prompt text is cut at "0 - Analyst Update".
            That marker is presumably where the Topics prompt starts listing its
            options — TODO confirm.
    """
    text = tokenizer.decode(sample["input_ids"], skip_special_tokens=True)
    pred = sample["prediction"]
    label = sample["label"]

    if topic_names is not None:
        text = text.split("0 - Analyst Update", maxsplit=1)[0]
        pred = topic_names[int(pred)]
        label = topic_names[int(label)]

    # Plain locals in the f-string: nesting the same quote type inside an
    # f-string is a SyntaxError before Python 3.12 (PEP 701).
    print(f"Sample: {text}\nPrediction: {pred}\nLabel: {label}")

In [None]:
dataset_id = "Topics"
# Optional length filter: keep only samples with fewer than MAX_INPUT_LEN input ids
# (e.g. 64). None keeps every sample.
MAX_INPUT_LEN = None

token_opts = args.token_opts[dataset_id]
testset = load_eval_dataset(tokenizer, dataset_id, args)

correct_dataset = []  # samples whose prediction equals the label
wrong_dataset = []    # samples whose prediction differs from the label

# Evaluation only: no autograd graph is needed.
with torch.no_grad():
    for sample in tqdm(testset):
        input_ids, attn_mask = get_tensors(sample)
        # Last attended position per row; assumes right padding — TODO confirm.
        gen_idx = attn_mask.sum(dim=1).long() - 1

        input_ids = input_ids.to(device)
        attn_mask = attn_mask.to(device)

        # Routing: score experts with the gate, then activate the top-1 expert adapter.
        # NOTE(review): if `expert` is a peft PeftModel, `disable_adapter()` is a
        # context manager and calling it bare has no effect — confirm intent.
        finMoE_model.expert.disable_adapter()
        gate_scores = finMoE_model.gate.forward(input_ids, attn_mask)
        expert_idx = torch.argmax(gate_scores, dim=-1).item()

        finMoE_model.expert.set_adapter(f"{expert_idx}")
        output = finMoE_model.expert.forward(input_ids, attn_mask)

        # Restrict the next-token logits to the dataset's answer tokens.
        gen_logits = output.logits[0, gen_idx, token_opts].cpu()
        local_argmax = torch.argmax(gen_logits, dim=-1).item()
        gen_token = token_opts[local_argmax]

        sample["prediction"] = tokenizer.decode(gen_token).strip(" ")
        sample["label"] = sample["options"][sample["gold_index"]]

        if MAX_INPUT_LEN is not None and len(sample["input_ids"]) >= MAX_INPUT_LEN:
            continue
        if sample["prediction"] == sample["label"]:
            correct_dataset.append(sample)
        else:
            wrong_dataset.append(sample)

In [None]:
# Pick one random example from each bucket. random.sample raises ValueError when
# k exceeds the population size, so cap k at the bucket size: an empty bucket
# (e.g. 100% accuracy) yields [] instead of crashing the cell.
correct_sample = random.sample(correct_dataset, k=min(1, len(correct_dataset)))
wrong_sample = random.sample(wrong_dataset, k=min(1, len(wrong_dataset)))

for bucket_kind, picked in (("correct", correct_sample), ("wrong", wrong_sample)):
    if picked:
        display(picked[0])
    else:
        print(f"No {bucket_kind} predictions to show.")