In [None]:
import sys
import os

# Make the parent of the current working directory importable (presumably so
# the `src` / `llm_stuff` packages imported in later cells resolve -- confirm).
# Guarded so that re-running this cell does not keep appending the same entry
# to sys.path.
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
    sys.path.append(project_root)

In [None]:
from tqdm import tqdm
from pathlib import Path
from src.utils.config_loader import load_config
from src.utils.seed import seed_everything

# Parent of the current working directory; the other cells build their data,
# prompt and result paths relative to it.
base_dir = Path.cwd().parent

# Settings are loaded from the secrets.yaml file located directly in base_dir.
config = load_config(base_dir / 'secrets.yaml')

# Fixed seed (42) handed to the project's seeding helper.
seed_everything(42)

In [None]:
from src.data.preprocessing import create_df

# CoNLL-U file used for validation, resolved relative to base_dir.
val_data_path = base_dir / 'data/my_data/regplans-dev.conllu'
val_df = create_df(val_data_path)

# TODO: turn shuffling off (original author's note; presumably refers to row
# shuffling done while building the frame -- confirm where it happens).

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch

# Checkpoint identifier handed to from_pretrained() for both the model and
# the tokenizer.
model_name = 'google/flan-t5-base'

# device_map="auto" leaves device placement of the weights to the library.
model = T5ForConditionalGeneration.from_pretrained(
    model_name,
    device_map="auto",
)
tokenizer = T5Tokenizer.from_pretrained(model_name)

# Alternative (disabled): generic Auto* classes with half-precision weights.
#tokenizer = AutoTokenizer.from_pretrained(model_name)
#model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map='auto', torch_dtype=torch.float16)

In [None]:
import json
import random

def format_examples(example_subset):
    """Render few-shot examples as a single prompt-ready string.

    Each example becomes a block made of a 1-based ``Example N:`` header, the
    sentence in double quotes, and one ``word label`` line per entity, closed
    by a ``##`` line. Blocks are joined with a newline.

    Args:
        example_subset: iterable of dicts with a ``"sentence"`` string and an
            ``"entities"`` list of dicts holding ``"word"`` and ``"label"``.

    Returns:
        The concatenated example blocks; an empty string for no examples.
    """
    blocks = []
    for number, example in enumerate(example_subset, start=1):
        pairs = (f"{entity['word']} {entity['label']}" for entity in example["entities"])
        entity_block = "\n".join(pairs)
        blocks.append(
            f"Example {number}:\nText: \"{example['sentence']}\"\nEntities:\n{entity_block}\n##\n"
        )

    return "\n".join(blocks)

# Load the bank of few-shot examples. The encoding is given explicitly so the
# read does not depend on the platform's default locale encoding.
with open(base_dir / 'llm_stuff/prompts/examples.json', 'r', encoding='utf-8') as f:
    example_bank = json.load(f)

# Ids of the examples to include in the prompt, in prompt order.
ids = [1]

# Pick the first bank entry matching each requested id. A missing id raises a
# descriptive KeyError instead of a bare StopIteration from next().
examples = []
for example_id in ids:
    match = next((ex for ex in example_bank if ex["id"] == example_id), None)
    if match is None:
        raise KeyError(f"No example with id {example_id!r} found in examples.json")
    examples.append(match)

formatted_examples = format_examples(examples)

print(formatted_examples)

In [None]:
from src.utils.label_mapping_regplans import label_to_id
from collections import defaultdict

# Upper bound on the number of generated tokens. Without an explicit bound,
# generate() falls back to the library's short default generation length,
# which can cut the entity list off mid-way.
MAX_NEW_TOKENS = 128


def parse_entity_lines(generated_text):
    """Collect ``word -> [label, ...]`` from lines shaped like ``<word> <label>``.

    Lines that do not split into exactly two whitespace-separated parts are
    ignored. Labels of a repeated word are kept in order of appearance.

    NOTE(review): this relies on the model emitting one pair per line; if the
    generated text arrives on a single line, nothing is parsed -- confirm
    against real generations.
    """
    entities = defaultdict(list)
    for line in generated_text.splitlines():
        parts = line.split()
        if len(parts) == 2:
            word, label = parts
            entities[word].append(label)
    return entities


def align_labels(tokens, entities):
    """Assign one predicted label per token.

    The k-th occurrence of a token receives the k-th label parsed for that
    word; tokens without a (remaining) parsed label default to ``"O"``.
    """
    pred_labels = []
    word_counts = defaultdict(int)  # How many labels of each word were consumed
    for token in tokens:
        candidates = entities.get(token, ())
        if word_counts[token] < len(candidates):
            pred_labels.append(candidates[word_counts[token]])
            word_counts[token] += 1
        else:
            pred_labels.append("O")
    return pred_labels


all_pred_ids = []
all_true_ids = []
all_results = []

# Evaluate on the first 1% of the validation rows only. The slice gets its own
# name so that re-running this cell does not shrink val_df again.
val_subset = val_df.iloc[:int(len(val_df) * 0.01)]

for _, row in tqdm(val_subset.iterrows(), total=len(val_subset)):

    sentence = row['full_text']
    tokens = row['words']
    true_labels = row['labels']

    input_text = f"""
        You are an expert in Named Entity Recognition (NER). Your task is to identify named entities that represent field zone names in the given text.

        The possible named entities are exclusively B-FELT (beginning of a field zone name) and I-FELT (continuation of the same field zone name).

        {formatted_examples}

        Return one line per token, including only tokens that are part of field zone names, each followed by its corresponding label, separated by a space.
                 
        Text: '{sentence}'

        Entities:
    """

    # Send the inputs to wherever the model lives instead of hard-coding
    # "cuda", so the cell also runs on CPU-only machines.
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)

    outputs = model.generate(input_ids, max_new_tokens=MAX_NEW_TOKENS)

    # Decode once and drop special tokens: markers such as the leading pad and
    # trailing end-of-sequence token would otherwise be glued to the first and
    # last "word label" lines and make them fail the two-part check.
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    entities = parse_entity_lines(generated_text)  # Word-label pairs
    pred_labels = align_labels(tokens, entities)

    # Convert labels to IDs; unknown labels fall back to the id of "O"
    # (or -1 if the mapping has no "O" entry).
    pred_ids = [label_to_id.get(label, label_to_id.get("O", -1)) for label in pred_labels]
    true_ids = [label_to_id[label] for label in true_labels]

    all_pred_ids.extend(pred_ids)
    all_true_ids.extend(true_ids)

    all_results.append({
        'sentence': sentence,
        'tokens': tokens,
        'true_labels': true_labels,
        'predicted_labels': pred_labels,
        'generated_text': generated_text,
    })

In [None]:
from llm_stuff.evaluation import evaluate 

# Score the flattened gold vs. predicted label ids collected by the loop.
metrics = evaluate(all_true_ids, all_pred_ids)

print("Evaluation Metrics on Val Set:")
print(metrics)

final_output = {
    # NOTE: input_text is the prompt left over from the last loop iteration,
    # so it embeds that iteration's sentence.
    'prompt': str(input_text),
    'evaluation_metrics': metrics,
    'results': all_results
}

# model_name can contain "/" (e.g. "google/flan-t5-base"), which makes the
# results path point into a sub-directory. Create it first so that open()
# does not fail with FileNotFoundError.
output_path = base_dir / f"llm_stuff/results/{model_name}_ONESHOT.json"
output_path.parent.mkdir(parents=True, exist_ok=True)

with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(final_output, f, indent=4, ensure_ascii=False)