In [1]:
import torch
from tqdm import tqdm
from datasets import load_dataset

# Fail fast when no CUDA device is visible: `device` below is hard-wired to
# "cuda", so nothing later in the notebook can run without a GPU.
assert torch.cuda.is_available(), "CUDA is not available"
# Device handle used by later cells to place the model and the input tensors.
device = torch.device("cuda")

## FinBERT-tone
Model from the Hugging Face Hub: [yiyanghkust/finbert-tone](https://huggingface.co/yiyanghkust/finbert-tone)

In [2]:
from transformers import BertTokenizer, BertForSequenceClassification

# Hub checkpoint shared by the tokenizer and the classifier.
model_id = 'yiyanghkust/finbert-tone'

tokenizer = BertTokenizer.from_pretrained(model_id)

# Build the 3-class sequence classifier first, then move it onto `device`.
model = BertForSequenceClassification.from_pretrained(
    model_id,
    num_labels=3,
)
model = model.to(device)

# Class names indexed by the predicted class id.
# NOTE(review): order is assumed to match the checkpoint's id2label mapping —
# confirm against model.config.id2label.
labels = ["Neutral", "Positive", "Negative"]

In [None]:
def classifier_evaluate(model, testset, labels):
    """Measure the classification accuracy of ``model`` on ``testset``.

    Examples are scored one at a time. The arg-max over the (flattened) logits
    is mapped to a label name through ``labels`` and compared with the gold
    option ``example["options"][example["gold_index"]]``. A running accuracy
    (in percent) is shown as the progress-bar description.

    Relies on the notebook-level ``device`` and ``tqdm`` names.

    Args:
        model: Callable that accepts ``input_ids`` and ``attention_mask``
            keyword arguments and returns an object exposing ``logits``.
        testset: Sized iterable of examples providing the keys ``input_ids``,
            ``attention_mask``, ``options`` and ``gold_index``.
            Presumably ``input_ids`` already carries a leading batch dimension
            of 1 (as produced by the preprocessing step) — confirm.
        labels: Label names indexed by the predicted class id.

    Returns:
        dict: ``{"accuracy": fraction of correctly classified examples}``.
        An empty ``testset`` yields ``{"accuracy": 0.0}`` instead of raising
        ``ZeroDivisionError``.
    """
    total = len(testset)
    if total == 0:
        # Nothing to score; avoid dividing by zero below.
        return {"accuracy": 0.0}

    correct = 0
    prog_bar = tqdm(testset)
    # Pure inference: disabling autograd avoids building (and holding) a
    # computation graph for every example.
    with torch.no_grad():
        for i, example in enumerate(prog_bar):
            input_ids = torch.tensor(example["input_ids"], device=device)
            attn_mask = torch.tensor(example["attention_mask"], device=device)

            # Call the module itself rather than ``.forward`` so that any
            # registered hooks are honoured.
            out = model(input_ids=input_ids,
                        attention_mask=attn_mask)
            # argmax without ``dim`` flattens the logits; int() yields a plain
            # Python index for the ``labels`` lookup.
            pred = int(torch.argmax(out.logits))
            if example["options"][example["gold_index"]] == labels[pred]:
                correct += 1

            prog_bar.set_description(f"{100 * correct / (i+1):.2f}")

    return {
        "accuracy": correct / total
    }

In [None]:
def finbert_preprocess(example, max_length=512):
    """Extract the sentence to classify from a prompt and tokenize it.

    Only the text after the last blank line (``"\\n\\n"``) of
    ``example['input']`` is kept, and everything from the marker
    ``"Question: what is the sentiment?"`` onward is discarded. When the
    marker is absent the whole segment is used (a plain ``find`` + slice
    would silently drop its last character in that case).

    Relies on the notebook-level ``tokenizer``.

    Args:
        example: Mapping with a string under the key ``'input'``.
        max_length: Length every sequence is truncated / padded to.

    Returns:
        The tokenizer output for the extracted sentence, requested as PyTorch
        tensors (``return_tensors="pt"``).
    """
    zeroshot: str = example['input'].rsplit("\n\n", maxsplit=1)[-1]
    # partition() returns the full string as its first element when the marker
    # is missing, so no character is lost in that case.
    sentence, _, _ = zeroshot.partition("Question: what is the sentiment?")
    return tokenizer(sentence,
                     truncation=True,
                     padding="max_length",
                     max_length=max_length,
                     return_tensors="pt")

# Load the "test" split of the FPB configuration of AdaptLLM/finance-tasks and
# tokenize it one example at a time with finbert_preprocess.
testset_adaptllm = load_dataset(
    path="AdaptLLM/finance-tasks",
    name="FPB",
    split="test",
).map(function=finbert_preprocess, batched=False)

Map:   0%|          | 0/970 [00:00<?, ? examples/s]

In [None]:
# Score the classifier on the preprocessed test split and show the metrics dict.
results = classifier_evaluate(model=model, testset=testset_adaptllm, labels=labels)
print(results)

59.07: 100%|██████████| 970/970 [00:15<00:00, 63.12it/s]

{'accuracy': 0.5907216494845361}



