# FinGPT Test: Named Entity Recognition (NER)

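This notebook demonstrates how to evaluate FinGPT models on the FinGPT Named Entity Recognition (NER) classification dataset (`FinGPT/fingpt-ner-cls`), where the model must label each entity as a location, person, or organization.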

## 1. Install Dependencies

In [None]:
!pip install transformers==4.32.0 peft==0.5.0 datasets accelerate bitsandbytes sentencepiece tqdm scikit-learn pandas matplotlib seaborn

## 2. Import Libraries

In [9]:
import os
import csv
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, PeftConfig
from tqdm import tqdm
import pandas as pd

## 3. Log In to Hugging Face

In [None]:
import os
from getpass import getpass

from huggingface_hub import login

# Never hard-code the access token: read it from the environment (e.g. an exported HF_TOKEN)
# and fall back to an interactive prompt, so the secret is not saved in the notebook.
hf_token = os.environ.get("HF_TOKEN") or getpass("Hugging Face access token: ")
login(token=hf_token)

## 4. Download the NER Dataset

In [None]:
!wget -O /content/test.parquet https://huggingface.co/datasets/FinGPT/fingpt-ner-cls/resolve/main/data/test-00000-of-00001-71355aae60cb2b0b.parquet

In [None]:
# Convert the downloaded parquet split to CSV so the processing cell can stream it with csv.DictReader.
parquet_path, csv_path = "test.parquet", "ner.csv"
df = pd.read_parquet(parquet_path)
df.to_csv(csv_path, index=False)

## 5. Load Models

In [None]:
import os
import csv
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, PeftConfig
from tqdm import tqdm
import pandas as pd

# Generation settings for this section.
# NOTE(review): neither name is referenced by the cells shown here (get_entity_response sets its
# own limits) — confirm whether they are still needed.
max_length = 512
batch_size = 4

def parse_model_name(base_model, from_remote=False):
    """Translate a short base-model key into its Hugging Face Hub repository id.

    Args:
        base_model: One of the keys listed below (e.g. ``'llama2'``, ``'qwen'``).
        from_remote: Accepted for call-site compatibility; not used by this function.

    Returns:
        The Hub repository id, e.g. ``'meta-llama/Llama-2-7b-hf'`` for ``'llama2'``.

    Raises:
        ValueError: If ``base_model`` is not a known key.
    """
    hub_ids = {
        'chatglm2':      'THUDM/chatglm2-6b',
        'llama2':        'meta-llama/Llama-2-7b-hf',
        'llama2-13b':    'meta-llama/Llama-2-13b-hf',
        'llama2-13b-nr': 'NousResearch/Llama-2-13b-hf',
        'baichuan':      'baichuan-inc/Baichuan-7B',
        'falcon':        'tiiuae/falcon-7b',
        'internlm':      'internlm/internlm-7b',
        'qwen':          'Qwen/Qwen-7B',
        'mpt':           'mosaicml/mpt-7b',
        'bloom':         'bigscience/bloom-7b1',
    }
    hub_id = hub_ids.get(base_model)
    if hub_id is None:
        raise ValueError(f"Unknown base model: {base_model}")
    return hub_id

def load_model_and_tokenizer(base_model, peft_model, from_remote=True):
    """Load a base causal LM, apply a PEFT adapter to it, and load the matching tokenizer.

    Args:
        base_model: Short key understood by ``parse_model_name`` (e.g. ``'llama2'``).
        peft_model: Hub id or local path of the PEFT adapter to apply.
        from_remote: If True, load the base model by its Hub id; otherwise load it from
            ``'../' + <hub id>`` on local disk.

    Returns:
        ``(model, tokenizer)``: the adapter-wrapped model switched to eval mode, and a
        tokenizer configured to pad on the left. Prints the pad/eos token ids as a side effect.
    """
    hub_id = parse_model_name(base_model)
    model_name = hub_id if from_remote else '../' + hub_id

    base = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        device_map="auto",
    )
    base.model_parallel = True
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Pad on the left for generation.
    tokenizer.padding_side = "left"

    if base_model == 'qwen':
        # Qwen's tokenizer needs its eos/pad ids set explicitly from its own special tokens.
        tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')
        tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids('<|extra_0|>')

    # Give the tokenizer a dedicated pad token when it has none, or when pad would be
    # indistinguishable from eos; the embedding matrix must grow to cover the new token.
    needs_dedicated_pad = (not tokenizer.pad_token) or tokenizer.pad_token_id == tokenizer.eos_token_id
    if needs_dedicated_pad:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        base.resize_token_embeddings(len(tokenizer))

    print(f'pad: {tokenizer.pad_token_id}, eos: {tokenizer.eos_token_id}')

    model = PeftModel.from_pretrained(base, peft_model).eval()
    return model, tokenizer

# Base model key (see parse_model_name) and the FinGPT LoRA adapter evaluated in this notebook.
base_model, peft_model = 'llama2', 'FinGPT/fingpt-mt_llama2-7b_lora'
model, tokenizer = load_model_and_tokenizer(
    base_model,
    peft_model,
    from_remote=True,
)

## 6. Define NER Function

In [None]:
def get_entity_response(context, instruction, max_input_length=512, max_new_tokens=64):
    """Ask the loaded FinGPT model for an entity label and return its answer text.

    Uses the module-level ``model`` and ``tokenizer`` created by ``load_model_and_tokenizer``.

    Args:
        context: The ``input`` text of a dataset row.
        instruction: The ``instruction`` text of a dataset row.
        max_input_length: Maximum number of prompt tokens; longer prompts are truncated.
        max_new_tokens: Maximum number of tokens generated for the answer.

    Returns:
        Only the newly generated text, with surrounding whitespace stripped.
    """
    prompt = f"Instruction: {instruction}\nInput: {context}\nAnswer: "

    # NOTE(review): if the tokenizer truncates on the right (the usual default), an over-long
    # prompt loses its trailing "Answer: " cue.
    tokens = tokenizer(prompt, return_tensors='pt', max_length=max_input_length, truncation=True)
    tokens = {k: v.to(model.device) for k, v in tokens.items()}
    prompt_length = tokens['input_ids'].shape[1]

    with torch.no_grad():
        # max_new_tokens bounds only the answer; max_length would count the prompt too, so a
        # prompt near the limit would leave no room to generate anything.
        outputs = model.generate(
            **tokens,
            max_new_tokens=max_new_tokens,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Decode only the tokens generated after the prompt. Splitting the full text on "Answer: "
    # picks the wrong segment whenever the input itself contains that string.
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()

# Source rows (converted from parquet earlier) and the destination for the model's predictions.
input_file, output_file = 'ner.csv', 'fingpt_ner.csv'

## 7. Process the Dataset

In [None]:
# Read all rows up front, then write each row back out with the model's answer appended.
# newline='' lets the csv module parse quoted fields that contain embedded newlines correctly.
with open(input_file, 'r', encoding='utf-8', newline='') as infile:
    reader = csv.DictReader(infile)
    all_rows = list(reader)
    # Copy instead of mutating the reader's list; `or []` guards an empty input file.
    fieldnames = list(reader.fieldnames or [])
if 'model answer' not in fieldnames:
    fieldnames.append('model answer')

with open(output_file, 'w', newline='', encoding='utf-8') as outfile:
    writer = csv.DictWriter(outfile, fieldnames=fieldnames)
    writer.writeheader()

    for row_number, row in enumerate(tqdm(all_rows, desc="Processing rows"), start=1):
        context = row.get('input', '')
        instruction = row.get('instruction', '')

        try:
            row['model answer'] = get_entity_response(context, instruction)
        except Exception as e:
            # Best effort: mark the row and keep going so one failure does not abort the run.
            # tqdm.write prints above the progress bar instead of garbling it.
            tqdm.write(f"Error on row {row_number}: {e}")
            row['model answer'] = 'ERROR'

        writer.writerow(row)
        # Flush per row so results written so far survive a kernel crash mid-run.
        outfile.flush()

## 8. Evaluation

We then manually clean the output file (`fingpt_ner.csv`) in Google Sheets, normalizing non-standard answers such as "a person", "an organization", or "industry, organization" to one of the three standard labels: "location", "person", or "organization".

We then run the following Google Apps Script, which writes 1 to a `scoring` column for each correct answer and 0 for each incorrect one.

In [None]:
"""
function checkAnswers() {
  // Writes 1 (correct) or 0 (incorrect) to a 'scoring' column for every data row by comparing
  // 'answer' with 'model answer' after trimming whitespace and ignoring case.
  // Incorrect rows are highlighted in yellow.
  const sheet = SpreadsheetApp.getActiveSpreadsheet().getActiveSheet();
  const data = sheet.getDataRange().getValues();

  const headers = data[0];
  const answerIndex = headers.indexOf('answer');
  const modelAnswerIndex = headers.indexOf('model answer');

  if (answerIndex === -1 || modelAnswerIndex === -1) {
    SpreadsheetApp.getUi().alert('Error: "answer" or "model answer" column not found.');
    return;
  }
  if (data.length < 2) {
    SpreadsheetApp.getUi().alert('Nothing to score: the sheet has no data rows.');
    return;
  }

  // Reuse an existing 'scoring' column so a re-run overwrites the old scores; otherwise append one.
  let scoringIndex = headers.indexOf('scoring');
  if (scoringIndex === -1) {
    scoringIndex = headers.length;
    sheet.getRange(1, scoringIndex + 1).setValue('scoring');
  }

  // getValues() can return numbers, dates or booleans, which have no toLowerCase(); coerce first.
  const normalize = (value) =>
    (value === null || value === undefined) ? '' : String(value).trim().toLowerCase();

  const scores = [];
  const backgrounds = [];
  for (let i = 1; i < data.length; i++) {
    const isCorrect = normalize(data[i][answerIndex]) === normalize(data[i][modelAnswerIndex]);
    scores.push([isCorrect ? 1 : 0]);
    backgrounds.push([isCorrect ? null : 'yellow']);  // null clears any earlier highlight
  }

  // One batched write for the whole column instead of several API calls per row.
  const scoreRange = sheet.getRange(2, scoringIndex + 1, scores.length, 1);
  scoreRange.setValues(scores);
  scoreRange.setBackgrounds(backgrounds);

  SpreadsheetApp.getUi().alert('Scoring completed!');
}
"""

We then compute the average of the `scoring` column to get the accuracy; this works because every score is either 0 or 1.

To compute the weighted F1 score, we run the following Google Apps Script.

In [None]:
"""
function computeWeightedF1() {
  // Support-weighted F1 over the gold labels in 'answer', using 'model answer' as predictions.
  // Labels are trimmed and lower-cased exactly as in checkAnswers, so accuracy and F1 agree on
  // which rows count as correct.
  const sheet = SpreadsheetApp.getActiveSpreadsheet().getActiveSheet();
  const data = sheet.getDataRange().getValues();

  const headers = data[0];
  const answerIndex = headers.indexOf('answer');
  const modelAnswerIndex = headers.indexOf('model answer');

  if (answerIndex === -1 || modelAnswerIndex === -1) {
    SpreadsheetApp.getUi().alert('Error: "answer" or "model answer" column not found.');
    return;
  }

  const normalize = (value) =>
    (value === null || value === undefined) ? '' : String(value).trim().toLowerCase();

  // Per-label stats: true positives (TP), false positives (FP), false negatives (FN), and
  // count (number of gold instances, i.e. the label's support). A Map avoids clashes between
  // label text and Object.prototype keys such as "constructor".
  const stats = new Map();
  const statsFor = (label) => {
    if (!stats.has(label)) {
      stats.set(label, { TP: 0, FP: 0, FN: 0, count: 0 });
    }
    return stats.get(label);
  };

  for (let i = 1; i < data.length; i++) {
    const trueLabel = normalize(data[i][answerIndex]);
    const predLabel = normalize(data[i][modelAnswerIndex]);

    if (trueLabel === '') continue;  // rows without a gold label are not scored

    const trueStats = statsFor(trueLabel);
    trueStats.count++;

    if (trueLabel === predLabel) {
      trueStats.TP++;
    } else {
      trueStats.FN++;
      statsFor(predLabel).FP++;
    }
  }

  let weightedF1Sum = 0;
  let weightSum = 0;

  stats.forEach(({ TP, FP, FN, count }) => {
    // Labels that only ever appear as predictions have zero support, hence zero weight.
    if (count === 0) return;

    const precision = (TP + FP) > 0 ? TP / (TP + FP) : 0;
    const recall = (TP + FN) > 0 ? TP / (TP + FN) : 0;
    const f1 = (precision + recall) > 0 ? (2 * precision * recall) / (precision + recall) : 0;

    weightedF1Sum += count * f1;
    weightSum += count;
  });

  const weightedF1 = weightSum > 0 ? weightedF1Sum / weightSum : 0;

  SpreadsheetApp.getUi().alert('Weighted F1 Score: ' + weightedF1);
}
"""