# Precision: Calculating how many of the extracted relationships are valid

In [6]:
!pip install -U nltk

Collecting nltk
  Using cached nltk-3.9.1-py3-none-any.whl.metadata (2.9 kB)
Using cached nltk-3.9.1-py3-none-any.whl (1.5 MB)
Installing collected packages: nltk
  Attempting uninstall: nltk
    Found existing installation: nltk 3.8.1
    Uninstalling nltk-3.8.1:
      Successfully uninstalled nltk-3.8.1
Successfully installed nltk-3.9.1

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [None]:
import fitz,os
from sentence_transformers import SentenceTransformer, util
import nltk
from nltk.stem import PorterStemmer
nltk.download("punkt")
import pandas as pd

In [None]:
# Load a pre-trained sentence transformer
model = SentenceTransformer('all-MiniLM-L6-v2')

# Initialize Python porter stemmer
ps = PorterStemmer()
def lemmatize(sent):
    '''
        Lemmatizes word in sentence
        args: 
            sent: string of words
        return:
            list of lemmatized words
    '''
    return [ps.stem(word) for word in sent.split()]
    
# Reading the pdf and hand annotation files
def read_pdf(pdf_file):
    '''
        Reads text from pdf
        args: 
            pdf_file: string representing name of file to read
        return:
            list of words from pdf
    '''
    start=False
    sentences=[]
    start_idx=0
    with fitz.open(pdf_file) as pdf_file:
        for page_index, page in enumerate(pdf_file):
            text = page.get_text("text").lower()
            text=text.split(". ")
            sentences.extend(text)
                
    return sentences


def read_files(root_dir, hand):
    '''
        Reads all files in a directory and all text in csv
        args:
            root_dir: directory from which to read files 
            hand: csv with hand annotations
        return:
            list of text from hand and documents in root_dir
    '''
    lines=[]
    for files in os.listdir(root_dir):
        if files[-4:] != '.pdf':
            continue
        sentences = read_pdf(f"{root_dir}/{files}")
        lines.extend(sentences)

    # read in hand annotations
    for p in hand.iterrows():
        rel = p[1]['rel']
        subj = p[1]['subj']
        obj = p[1]['obj']
        out=f"{subj} {rel} {obj}" 
        lines.append(out)

    return lines

#computing cosine similarity
def vec(sentences):
    '''
        Computes cosine similarity between 2 sentences
        args:
            sentences: list of 2 sentences
        return:
            similarity score
    '''
    # Encode sentences
    embeddings = model.encode([sentences[0], sentences[1]])
    
    # Compute cosine similarity
    similarity = util.cos_sim(embeddings[0], embeddings[1])
    return similarity.item() # Value close to 1 indicates high similarity
    
#finding if the target string (relation triplet) is in the src (pdf + hand annotation)
def find(target, src):
    '''
        Finds target sentence in src
        args:
            target: sentence to search for
            src: list to search through

        return:
            boolean representing if text was found and sentence most closely aligning with target
    '''
    found=False
    matching_sentence=""
 
    for idx,sentence in enumerate(src):
        pred=" ".join(lemmatize(target))
        test=" ".join(lemmatize(sentence))
        cos = vec([pred,test])
        if pred in test or cos > 0.7:
            if cos >0.65 and cos < 0.7:
                print(f"Got a match for {pred }: {sentence}")
            elif cos <=0.65:
                print(f"Closest match to {pred} was {test}")
            found=True
            st_idx=idx
            matching_sentence=sentence
            return found, matching_sentence, 
            
    return found, matching_sentence


In [None]:
def main(ground_truth_file, pred_files_dir):
    """
        Goes through prediction files and finds how many of the extracted relationships are found in the ground truth
        args:
            ground_truth_file: file containing results to compare predictions to (combination of manual annotations and src documents)
            pred_files: LLM extractions
        return:
            precision per file
    """
    ground_truth = pd.read_csv(f"../Results/{ground_truth_file}")
    pred_files = os.listdir(pred_files_dir)
                            
    sentences = read_files("../Docs",ground_truth)
    
    
    for pred_file in pred_files
        pred_file = f"../Results/{pred_file}"
        preds = pd.read_csv(pred_file) 
        score = 0
        for p in preds.iterrows():
            ref = p[1]['ref']
            rel = p[1]['rel']
            subj = p[1]['subj']
            obj = p[1]['obj']
            out=f"{subj} {rel} {obj}"
            
            found, match = find(out, sentences)
            if found:
                score +=1
            else:
                print("Couldn't find a match for  ", out)
        print(f"Precision for {pred_file} is {score/len(preds)}")
ground_truth_file = "ground_truth.csv"
pred_files_dir = "../Results"
main(ground_truth_file, pred_files_dir)

In [None]:
Precision for ../NewRels_Skip3_PassingInIncrements.csv : 0.8181818181818182
Precision for ../NewRels_Skip4_increments.csv is 0.6785714285714286
Precision for ../Results/NewRels_Skip3_cummulative.csv is 0.8584905660377359
Precision for ../Results/NewRels_Skip2_cummulative.csv is 0.901840490797546
Precision for ../Results/NewRels_Skip2_increments.csv is 0.7616279069767442
Precision for ../Results/Temperature0point2.csv is 0.7981651376146789