In [None]:

import re
import pickle
from collections import Counter

# Metrics from seqeval used to score the BIO-tagged predictions
from seqeval.metrics import classification_report, f1_score, precision_score, recall_score

import random

# Seed Python's built-in `random` module so any later use of it is repeatable.
# NOTE(review): no visible cell calls `random`; confirm the seed is still needed.
random.seed(42)


In [None]:
# Entity types evaluated in this notebook; the order here is the order in which
# per-entity predictions are processed by the scoring functions.
entities_list = [
    "Location",
    "Date",
    "Person",
    "Organization",
    "Event",
]

In [None]:
def parse_file(file_path):
  """Read a whitespace-separated BIO file into parallel token/label lists.

  Each non-blank line contributes its first field as the token and its second
  field as the label. Blank lines separate sentences.

  Args:
    file_path: Path of the BIO-formatted text file.

  Returns:
    A tuple ``(tokens, labels)``: two lists of equal length, each holding one
    list per sentence. Empty sentences are never emitted, so runs of blank
    lines (or blank lines at the start/end of the file) do not create
    zero-length entries.

  Raises:
    IndexError: If a non-blank line has fewer than two fields.
  """
  tokens, labels = [], []
  sentence_tokens, sentence_labels = [], []

  # Decode explicitly as UTF-8 instead of relying on the platform default
  # encoding, which is not UTF-8 on every system.
  with open(file_path, "r", encoding="utf-8") as f:
    for line in f:
      fields = line.split()
      if not fields:
        # Sentence boundary: flush only if something was collected, so
        # repeated blank lines do not produce empty sentences.
        if sentence_tokens:
          tokens.append(sentence_tokens)
          labels.append(sentence_labels)
          sentence_tokens, sentence_labels = [], []
        continue
      sentence_tokens.append(fields[0])
      sentence_labels.append(fields[1])

  # Flush the final sentence when the file does not end with a blank line.
  if sentence_tokens:
    tokens.append(sentence_tokens)
    labels.append(sentence_labels)

  return tokens, labels

def get_news_data_sets():
  """Parse the EverestNER train and test BIO files.

  Returns:
    ``(train_data, test_data)``, each being the ``(tokens, labels)`` pair
    produced by ``parse_file``.
  """
  split_paths = (
    "everest-ner/EverestNER-train-bio.txt",
    "everest-ner/EverestNER-test-bio.txt",
  )
  # Train is parsed first, then test, matching the order of the paths above.
  train_data, test_data = [parse_file(path) for path in split_paths]
  return train_data, test_data

def get_tweets_data_sets():
  """Parse the DanfeNER train and test BIO files.

  Returns:
    ``(train_data, test_data)``, each being the ``(tokens, labels)`` pair
    produced by ``parse_file``.
  """
  split_paths = (
    "DanfeNER/DanfeNER-train-bio.txt",
    "DanfeNER/DanfeNER-test-bio.txt",
  )
  # Train is parsed first, then test, matching the order of the paths above.
  train_data, test_data = [parse_file(path) for path in split_paths]
  return train_data, test_data



In [None]:
# Load the news train/test splits; each split is the (sentences, labels) pair
# returned by parse_file, unpacked below into separate names.
news_train,news_test=get_news_data_sets()
news_train_sentences, news_train_labels = news_train
news_test_sentences, news_test_labels = news_test
# Cell output: (number of training sentences, number of test sentences).
len(news_train_sentences),len(news_test_sentences)

In [None]:
def convert_to_bio(entity, prediction, original_sentence):
    """Turn a marker-annotated prediction into BIO labels for one entity type.

    The prediction is split on whitespace and every token gets one label:
    a token containing ``@@`` starts an entity (``B-<entity>``), a token
    containing ``##`` (without ``@@``) ends one (``I-<entity>``), tokens between
    an opening and a closing marker are ``I-<entity>``, and everything else is
    ``O``.

    Args:
        entity: Entity type name used as the label suffix.
        prediction: Model output text with ``@@``/``##`` markers.
        original_sentence: Reference sentence; only its token count is used,
            as the fallback length when the prediction has no tokens.

    Returns:
        A list of BIO labels, one per whitespace token of ``prediction``; if the
        prediction is empty, one ``"O"`` per token of ``original_sentence``.
    """
    begin_label = f"B-{entity}"
    inside_label = f"I-{entity}"
    fallback_length = len(original_sentence.split())

    labels = []
    span_open = False  # True while between an opening and a closing marker
    for word in prediction.split():
        has_open = '@@' in word
        has_close = '##' in word
        if has_open:
            labels.append(begin_label)
            # A token carrying both markers is a complete one-token entity.
            span_open = not has_close
        elif has_close:
            labels.append(inside_label)
            span_open = False
        else:
            labels.append(inside_label if span_open else "O")

    if not labels:
        return ["O"] * fallback_length
    return labels

In [None]:
def merge_bio_labels_with_continuation_priority(lists, tokens):
    """
    Combine several BIO label sequences into one sequence.

    For every token position the non-"O" labels from all sequences are
    collected (a sequence shorter than the position counts as "O"):

    * none  -> "O"
    * one   -> that label
    * many  -> each candidate starting with "B-" is scored by how many
      consecutive "I-<type>" labels follow this position, summed over all
      sequences; the highest-scoring candidate wins (first one on ties).
      If no candidate starts with "B-", the first candidate is used.

    NOTE(review): in a conflict, "I-" candidates are never scored, so a "B-"
    label always wins over an "I-" label that continues an entity chosen at
    the previous position. Confirm this is intended.

    Args:
        lists (list of lists): BIO-tagged lists to merge.
        tokens (list): Reference sequence; only its length is used.

    Returns:
        list: Merged BIO labels with the same length as ``tokens``.
    """
    merged = []

    for idx in range(len(tokens)):
        # Non-"O" labels proposed for this position.
        candidates = [
            tags[idx]
            for tags in lists
            if idx < len(tags) and tags[idx] != "O"
        ]

        if not candidates:
            merged.append("O")
            continue

        if len(candidates) == 1:
            merged.append(candidates[0])
            continue

        # Conflict: score every entity-starting candidate by its continuation.
        scores = {}
        for label in candidates:
            if not label.startswith("B-"):
                continue
            inside_label = f"I-{label[2:]}"
            run_total = 0
            for tags in lists:
                cursor = idx + 1
                while cursor < len(tags) and tags[cursor] == inside_label:
                    run_total += 1
                    cursor += 1
            scores[label] = run_total

        if scores:
            merged.append(max(scores, key=scores.get))
        else:
            # No "B-" candidate to score: fall back to the first proposal.
            merged.append(candidates[0])

    return merged

In [None]:
def align_sentences(S, T):
    """
    Re-tokenize the predicted sentence T against the reference sentence S by
    gluing consecutive T tokens together (markers removed) until they spell
    the current S token.

    For each S token, T tokens are consumed one by one:
    * if a single T token matches, it is kept as-is (markers included);
    * if several T tokens had to be joined and the joined text is longer than
      the last consumed T token, the result is emitted as ``@@<joined>##``;
      otherwise the last consumed T token is emitted;
    * if T runs out before a match, the S token is emitted unchanged.
    Any T tokens left after all S tokens are processed are appended.

    NOTE(review): a joined match is wrapped in ``@@…##`` even when none of the
    joined pieces carried a marker. Confirm this is intended.
    NOTE(review): a second ``align_sentences`` is defined in a later cell of
    this notebook; once that cell runs it replaces this definition.

    Parameters:
    S (str): The reference sentence.
    T (str): The predicted sentence.

    Returns:
    str: The aligned tokens joined by single spaces.
    """
    reference = S.split()
    predicted = T.split()
    predicted_count = len(predicted)

    aligned = []
    cursor = 0  # index of the next unconsumed token of T

    for ref_word in reference:
        rebuilt = ""
        matched = False
        while cursor < predicted_count and not matched:
            piece = predicted[cursor]
            cursor += 1
            # Compare on the text without the entity markers.
            rebuilt += piece.replace("@@", "").replace("##", "")
            if rebuilt == ref_word:
                matched = True
                if len(rebuilt) > len(piece):
                    aligned.append(f"@@{rebuilt}##")
                else:
                    aligned.append(piece)
        if not matched:
            # T exhausted without spelling this reference token.
            aligned.append(ref_word)

    # Keep whatever is left of T.
    aligned.extend(predicted[cursor:])

    return ' '.join(aligned)




In [None]:
def post_process_output(output):
    """Trim a raw model response down to the tagged sentence.

    Steps, applied in order:
    1. For each of the headers "Output:", "नतिजा:" and "वाक्य:", if present,
       keep the text between its first and second occurrence (or to the end)
       and strip surrounding whitespace.
    2. If "Note" is present, keep only the text before it (stripped).
    3. If the danda "।" is present, cut everything after its first occurrence.

    Args:
        output: Raw response text.

    Returns:
        The cleaned prediction; the input is returned unchanged when none of
        the markers occur.
    """
    prediction = output

    # Answer headers the response may be prefixed with, checked in this order.
    for header in ("Output:", "नतिजा:", "वाक्य:"):
        if header in prediction:
            prediction = prediction.split(header)[1].strip()

    # Drop any trailing commentary introduced by "Note".
    if "Note" in prediction:
        prediction = prediction[:prediction.index("Note")].strip()

    # Keep a single sentence: everything up to and including the first danda.
    if "।" in prediction:
        head, _, _ = prediction.partition("।")
        prediction = head + "।"

    return prediction

In [None]:
# Load the pickled train/test datasets ("rb" = read binary).
# NOTE(review): pickle.load can execute arbitrary code while loading; only
# load files that you created or otherwise trust.
# Later cells read test_datasets_with_tagging[sentence][1] as the label list
# of a sentence; the rest of the structure is not visible here.
with open("output_path/train_datasets_with_tagging.pkl", "rb") as file:
    train_datasets_with_tagging = pickle.load(file)

with open("output_path/test_datasets_with_tagging.pkl", "rb") as file:
    test_datasets_with_tagging = pickle.load(file)


In [None]:
# For every test sentence, build one label sequence per entity type: labels
# that mention the entity type are kept, all other positions become "O".
# Result shape: {sentence: {entity_type: [label, ...]}}.
ground_truth_separated = {}
for sentence in test_datasets_with_tagging:
    gold_labels = test_datasets_with_tagging[sentence][1]
    ground_truth_separated[sentence] = {
        ent: [label if ent in label else "O" for label in gold_labels]
        for ent in entities_list
    }

In [None]:
# File name of the pickled model outputs to load.
output_file_name="pickled_file.pkl"

# NOTE(review): absolute, user-specific directory; this cell only runs where
# that path exists. Consider a configurable output directory.
# NOTE(review): pickle.load can execute arbitrary code while loading; only
# load files that you created or otherwise trust.
with open("/home/sneupane/NER/Flairs_paper/output/"+output_file_name , "rb") as file:
    pickled_file = pickle.load(file)



In [None]:
def align_sentences(S, T):
    """
    Align the predicted sentence T to the reference sentence S position by
    position.

    At each position of S: if T has a token there and it contains ``@@`` or
    ``##``, the T token is kept; otherwise the S token is used (this also pads
    the result when T is shorter than S). Tokens of T beyond the length of S
    are appended unchanged, so the result can still be longer than S.

    NOTE(review): an earlier cell of this notebook defines a different
    ``align_sentences``; this definition replaces it once this cell runs.

    Parameters:
    S (str): The reference sentence.
    T (str): The predicted sentence.

    Returns:
    str: The aligned tokens joined by single spaces.
    """
    ref_tokens = S.split()
    pred_tokens = T.split()
    pred_count = len(pred_tokens)

    aligned = []
    for position, ref_tok in enumerate(ref_tokens):
        if position >= pred_count:
            # T is shorter than S: pad with the reference token.
            aligned.append(ref_tok)
            continue

        pred_tok = pred_tokens[position]
        carries_marker = '@@' in pred_tok or '##' in pred_tok
        # Only marker-bearing predicted tokens are kept; anything else is
        # replaced by (or already equals) the reference token.
        aligned.append(pred_tok if carries_marker else ref_tok)

    # Surplus predicted tokens are kept as they are.
    aligned.extend(pred_tokens[len(ref_tokens):])

    return ' '.join(aligned)

In [None]:


def calculate_f1_merge(output_file_name):
    """Score merged multi-entity predictions and print the metrics.

    For every sentence, the raw output of each entity type in ``entities_list``
    is cleaned with ``post_process_output``, converted with ``convert_to_bio``
    and the per-entity label lists are merged with
    ``merge_bio_labels_with_continuation_priority`` to the length of the
    reference labels. Reference and merged labels are upper-cased before
    scoring.

    Args:
        output_file_name: Despite the name, this is indexed as a mapping
            ``{sentence: {entity_type: raw_output}}``; every sentence must
            also be a key of ``test_datasets_with_tagging``.

    Returns:
        None. The classification report, precision, recall and F1 are printed.
    """
    y_true = []
    y_pred = []

    for sentence in output_file_name:
        reference_labels = test_datasets_with_tagging[sentence][1]
        raw_outputs = output_file_name[sentence]

        # One BIO sequence per entity type, in entities_list order.
        per_entity_labels = [
            convert_to_bio(entity, post_process_output(raw_outputs[entity]), sentence)
            for entity in entities_list
        ]
        merged_labels = merge_bio_labels_with_continuation_priority(
            per_entity_labels, reference_labels
        )

        y_true.append([tag.upper() for tag in reference_labels])
        y_pred.append([tag.upper() for tag in merged_labels])

    print(classification_report(y_true, y_pred))
    print("\nPrecision:", precision_score(y_true, y_pred))
    print("Recall:", recall_score(y_true, y_pred))
    print("F1-Score:", f1_score(y_true, y_pred))

In [None]:


def calculate_f1_individual(output_file_name):
    """Score each entity type's prediction separately and print the metrics.

    For every sentence and entity type, the raw output is cleaned with
    ``post_process_output`` and converted with ``convert_to_bio``. If the
    label count differs from the reference, the prediction is passed through
    ``align_sentences`` and converted again. Pairs whose lengths still differ
    are left out of the scores and recorded as errors.

    Args:
        output_file_name: Despite the name, this is indexed as a mapping
            ``{sentence: {entity_type: raw_output}}``; every sentence must
            also be a key of ``ground_truth_separated``.

    Returns:
        dict: ``{sentence: [entity_type, cleaned_prediction]}`` for sentences
        with at least one unalignable prediction. Only the last failing entity
        type of a sentence is kept, because the entry is overwritten.
    """
    y_true1 = []
    y_pred1 = []
    error_sentences = {}

    for sentence in output_file_name:
        true_label = ground_truth_separated[sentence]

        for entity in entities_list:
            prediction = post_process_output(output_file_name[sentence][entity])
            predicted_labels = convert_to_bio(entity, prediction, sentence)
            expected_length = len(true_label[entity])

            if len(predicted_labels) != expected_length:
                # Token counts disagree: retry once on the aligned prediction.
                aligned_T = align_sentences(sentence, prediction)
                predicted_labels = convert_to_bio(entity, aligned_T, sentence)

            if len(predicted_labels) == expected_length:
                y_true1.append(true_label[entity])
                y_pred1.append(predicted_labels)
            else:
                error_sentences[sentence] = [entity, prediction]

    print(classification_report(y_true1, y_pred1))
    print("\nPrecision:", precision_score(y_true1, y_pred1))
    print("Recall:", recall_score(y_true1, y_pred1))
    print("F1-Score:", f1_score(y_true1, y_pred1))
    print("Error sentences : ", len(error_sentences))
    return error_sentences