In [1]:
import sys
import os

# Put the parent of the current working directory on the import path; later
# cells import the `src` and `llm_stuff` packages from there.
# The membership check keeps the cell idempotent: re-running it no longer
# appends the same entry to sys.path again.
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
    sys.path.append(project_root)

In [2]:
from tqdm import tqdm
from pathlib import Path
from src.utils.config_loader import load_config
from src.utils.seed import seed_everything

# Parent of the directory the notebook is executed from; every project path
# below is built relative to it.
base_dir = Path.cwd().parent

# Settings are read from secrets.yaml in that directory (later cells take the
# Azure OpenAI values from this mapping).
secrets_path = base_dir / 'secrets.yaml'
config = load_config(secrets_path)

# Fixed seed handed to the project's seeding helper.
seed_everything(42)

In [3]:
from src.data.preprocessing import create_df

# Read the CoNLL-U file through the project's create_df helper; later cells
# access the 'full_text', 'words' and 'labels' columns of the result.
dev_conllu_path = base_dir / 'data/my_data/regplans-dev.conllu'
val_df = create_df(dev_conllu_path)

In [4]:
from langchain_openai import AzureChatOpenAI
from langchain_core.messages import (SystemMessage, HumanMessage)

# Export the Azure OpenAI connection settings as environment variables
# (environment variable name -> key in the loaded config). They are not passed
# to the constructor below, so the client presumably reads them from the
# environment -- verify against the langchain_openai documentation.
azure_env_from_config = {
    'OPENAI_API_VERSION': 'OPENAI_API_VERSION',
    'AZURE_OPENAI_ENDPOINT': 'OPENAI_API_BASE',
    'AZURE_OPENAI_API_KEY': 'OPENAI_API_KEY',
}
for env_name, config_key in azure_env_from_config.items():
    os.environ[env_name] = config[config_key]

# Chat model used for all predictions in this notebook, with temperature 0.0.
llm = AzureChatOpenAI(
    temperature=0.0,
    deployment_name=config['OPENAI_DEPLOYMENT_NAME'],
)

In [5]:
import json
import random

def format_examples(example_subset):
    """Render few-shot examples as a single string for use in the prompt.

    Args:
        example_subset: iterable of dicts, each with a "sentence" value and an
            "entities" list whose items have "word" and "label" keys.

    Returns:
        The examples numbered from 1. Each one shows its quoted text, then one
        "word label" line per entity, then a "##" terminator line; consecutive
        examples are separated by a blank line. An empty input gives "".
    """
    def render(number, example):
        # One "word label" line per annotated entity (empty if there are none).
        entity_block = "\n".join(
            f"{entity['word']} {entity['label']}" for entity in example["entities"]
        )
        return (
            f"Example {number}:\n"
            f"Text: \"{example['sentence']}\"\n"
            "Entities:\n"
            f"{entity_block}\n"
            "##\n"
        )

    return "\n".join(
        render(number, example)
        for number, example in enumerate(example_subset, start=1)
    )

# Load the bank of few-shot examples. The examples contain non-ASCII
# characters (see the printed output below), so decode explicitly as UTF-8
# instead of relying on the platform default encoding; the results file is
# written as UTF-8 as well.
with open(base_dir / 'llm_stuff/prompts/examples.json', 'r', encoding='utf-8') as f:
    example_bank = json.load(f)

# Ids of the examples to put in the prompt, in prompt order.
ids = [1, 19, 16, 3, 21]

# Alternative: draw a random subset instead of the fixed ids above.
#ids = random.sample(range(1, 26), 3)

# Index the bank once. setdefault keeps the first entry for a repeated id,
# which matches a "first match wins" linear search.
examples_by_id = {}
for ex in example_bank:
    examples_by_id.setdefault(ex["id"], ex)

# Fail with a readable message when a requested id is absent, rather than
# with a bare StopIteration raised from a generator.
missing_ids = [example_id for example_id in ids if example_id not in examples_by_id]
if missing_ids:
    raise KeyError(f"Example id(s) {missing_ids} not found in examples.json")

examples = [examples_by_id[example_id] for example_id in ids]

formatted_examples = format_examples(examples)

print(formatted_examples)

Example 1:
Text: "Adkomst til BFS1 og BFS2 skal være fra Solfjellveien ."
Entities:
BFS1 B-FELT
BFS2 B-FELT
##

Example 2:
Text: "Parkeringsplassar ( SPP ) Grøntstruktur , jf . PBL § 12-5 , 2 . ledd nr . 3 - Turveg ( GT )"
Entities:
SPP B-FELT
GT B-FELT
##

Example 3:
Text: "Før det vert gjeve mellombels bruksløyve / ferdigattest for ny bueining innanfor felt BKS1 og BFS14 og 15"
Entities:
BKS1 B-FELT
BFS14 B-FELT
og I-FELT
15 I-FELT
##

Example 4:
Text: "Bebyggelsestype Innenfor BKS1-BKS6 og BFS2 skal det oppføres flermannsboliger , kjedeboliger og / eller rekkehus ."
Entities:
BKS1-BKS6 B-FELT
BFS2 B-FELT
##

Example 5:
Text: "Areal brattere enn 1 : 3 , arealer i gul eller rød sone for henholdsvis støy ( T-1442 ) og luftkvalitet ( T-1520 ) ."
Entities:

##



In [6]:
from src.utils.label_mapping_regplans import label_to_id
from collections import defaultdict

# Share of the validation rows that is sent to the LLM (first rows only).
EVAL_FRACTION = 0.5

# Flat, token-level accumulators over all evaluated rows. all_pred_ids and
# all_true_ids must stay index-aligned because they are compared element-wise.
all_pred_ids = []
all_true_ids = []
all_results = []   # one dict per evaluated row, saved to JSON later
skipped_rows = []  # index labels of rows that were left out of the evaluation

# Slice into a NEW variable: overwriting val_df here would halve it again on
# every re-run of this cell.
val_subset = val_df.iloc[:int(len(val_df) * EVAL_FRACTION)]


def _parse_entity_lines(generated_text):
    """Collect, per word, the labels the model emitted, in output order.

    Only lines made of exactly two whitespace-separated fields ("word label")
    are used; every other line is ignored.
    """
    labels_by_word = defaultdict(list)
    for line in generated_text.splitlines():
        parts = line.split()
        if len(parts) == 2:
            word, label = parts
            labels_by_word[word].append(label)
    return labels_by_word


def _align_labels(tokens, labels_by_word):
    """Assign the k-th emitted label of a word to its k-th occurrence in tokens.

    Tokens the model did not list (or listed fewer times than they occur) get "O".
    """
    seen = defaultdict(int)  # how many emitted labels of each word are used up
    aligned = []
    for token in tokens:
        emitted = labels_by_word.get(token, ())
        if seen[token] < len(emitted):
            aligned.append(emitted[seen[token]])
            seen[token] += 1
        else:
            aligned.append("O")
    return aligned


# Id used for any predicted label that is not in label_to_id.
fallback_id = label_to_id.get("O", -1)

for idx, row in tqdm(val_subset.iterrows(), total=len(val_subset)):

    sentence = row['full_text']
    tokens = row['words']
    true_labels = row['labels']

    # NOTE: the whitespace inside the triple-quoted strings is part of the
    # prompt sent to the model, so their layout is kept exactly as it was.
    msg = [
    SystemMessage(
        """
        You are an expert in Named Entity Recognition (NER). Your task is to identify named entities that represent field zone names in the given text.
        """
    ),
    HumanMessage(
        f""" 
        The possible named entities are exclusively B-FELT (beginning of a field zone name) and I-FELT (continuation of the same field zone name).

        {formatted_examples}

        Return one line per token, including only tokens that are part of field zone names, each followed by its corresponding label, separated by a space.
                 
        Text: '{sentence}'

        Entities:
        """
    )]

    try:
        # A length mismatch would shift every later entry of the flat id
        # lists against each other, so reject the row before calling the LLM.
        if len(tokens) != len(true_labels):
            raise ValueError(
                f"{len(tokens)} tokens but {len(true_labels)} gold labels"
            )

        response = llm.invoke(msg)

        pred_labels = _align_labels(tokens, _parse_entity_lines(response.content))

        # Convert labels to IDs; unknown predicted labels count as "O".
        pred_ids = [label_to_id.get(label, fallback_id) for label in pred_labels]
        true_ids = [label_to_id[label] for label in true_labels]

    except Exception as e:
        # Best effort: one failing row must not abort the whole run, but keep
        # a record of it so the reduced evaluation set is visible afterwards.
        print(f"Skipping row {idx} due to error: {e}")
        skipped_rows.append(idx)
        continue

    all_pred_ids.extend(pred_ids)
    all_true_ids.extend(true_ids)

    all_results.append({
        'sentence': sentence,
        'tokens': tokens,
        'true_labels': true_labels,
        'predicted_labels': pred_labels,
        'generated_text': response.content
    })

if skipped_rows:
    print(f"Skipped {len(skipped_rows)} of {len(val_subset)} rows; they are not part of the evaluation.")

100%|██████████| 176/176 [17:16<00:00,  5.89s/it]


In [7]:
from llm_stuff.evaluation import evaluate 

# Token-level evaluation over all rows that were not skipped in the loop above.
metrics = evaluate(all_true_ids, all_pred_ids)

print("Evaluation Metrics on Val Set:")
print(metrics)

final_output = {
    # `msg` is the variable left over from the last loop iteration, so the
    # stored prompt contains the sentence of the last processed row.
    'prompt': str(msg),
    'evaluation_metrics': metrics,
    'results': all_results
}

results_path = base_dir / f"llm_stuff/results/{config['OPENAI_DEPLOYMENT_NAME']}_FEWSHOT_5_V2.json"

# Create the results folder if needed, so a missing directory cannot discard
# the outcome of the long LLM run at the very last step.
results_path.parent.mkdir(parents=True, exist_ok=True)

with open(results_path, 'w', encoding='utf-8') as f:
    json.dump(final_output, f, indent=4, ensure_ascii=False)

Evaluation Metrics on Val Set:
{'precision': 0.7729340195655823, 'recall': 0.7344198822975159, 'f1': 0.7513071298599243, 'span_acc': 0.8362069129943848, 'classification_report': {'B-FELT': {'precision': 0.8947368421052632, 'recall': 0.8793103448275862, 'f1-score': 0.8869565217391304, 'support': 116.0}, 'I-FELT': {'precision': 0.4375, 'recall': 0.3333333333333333, 'f1-score': 0.3783783783783784, 'support': 21.0}, 'O': {'precision': 0.9865654205607477, 'recall': 0.9906158357771261, 'f1-score': 0.9885864793678666, 'support': 1705.0}, 'accuracy': 0.9761129207383279, 'macro avg': {'precision': 0.772934087555337, 'recall': 0.7344198379793485, 'f1-score': 0.7513071264951251, 'support': 1842.0}, 'weighted avg': {'precision': 0.9745228098481463, 'recall': 0.9761129207383279, 'f1-score': 0.9752295601465242, 'support': 1842.0}}}
