In [257]:
# Imports
import os
from groq import Groq
import json
import jinja2
import re
from typing import Dict, List, Tuple

# **Prompt engineering with LLMs**
---

Currently, this notebook has only been tested with models served through the Groq API.
Models served through the OpenAI API will be added soon.

In [258]:
# Create the Groq client; the API key is read from the GROQ_API_KEY environment variable
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

Here we have:
- paths to the prompt template and the annotation files folder


- the list of types of prompts (zero-shot, one-shot, few-shot)


- the list of models tested


- the list of entities we are interested in

**MODIFY THE LIST OF PROMPTS AND MODELS TO YOUR LIKING**

In [259]:
# Folder with the ground-truth annotation files (JSON)
ANNOTATIONS_FOLDER = "../annotations/"

# Folder with the prompt templates
PROMPT_PATH = "../prompt_templates/"

# List of prompt templates, given as file names without extension
# (the annotation loop appends ".txt" and looks them up in PROMPT_PATH)
LIST_PROMPTS = [
    "zero_shot_prompt_template",
    "one_shot_prompt_template",
    "few_shot_prompt_template"
]

# List of models to test (identifiers passed as `model` to the chat completion call).
# NOTE: these identifiers are also joined into output paths, so an identifier
# containing "/" produces a nested output folder.
LIST_MODELS = [
    # "gemma2-9b-it",
    # "mistral-saba-24b",
    "llama-3.3-70b-versatile",
    "meta-llama/llama-4-maverick-17b-128e-instruct",
    "deepseek-r1-distill-llama-70b"
    # "deepseek-r1-distill-qwen-32b" # has been decommissioned
    ]

# List of entity tags that the LLMs are asked to annotate and that are then extracted
TAGS = [
    "MOL", "SOFTNAME", "SOFTVERS", "STIME", "TEMP", "FFM"
]

Firstly, we need a helper function to extract certain information from the ground-truth data:
- The input text that needs to be annotated (`input_text`)


- The manually-found entities (`ground_truth_entities`)

In [260]:
# Process one JSON file to extract the ground truth entities and the input text
def process_json_file(json_file: str) -> tuple:
    """
    Read a ground-truth annotation file and return its text and entity spans.

    Only the first entry of the top-level "annotations" list is read.

    Args:
        json_file (str): Path to the ground-truth JSON file.

    Returns:
        tuple: (input_text, ground_truth_entities), i.e. the first element of the
        first annotation entry and the "entities" value of its second element.
    """
    # Decode explicitly as UTF-8 so the text is read the same way on every platform,
    # instead of relying on the locale's default encoding.
    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Extract the input text and its annotated entities
    annotation_entry = data["annotations"][0]
    input_text = annotation_entry[0]
    ground_truth_entities = annotation_entry[1]["entities"]

    return input_text, ground_truth_entities

Next, we will define several helper functions to assist with the annotation workflow:

1. **Render the prompt template**  

   We use a Jinja2 template to dynamically inject the text that needs to be annotated. This allows for flexible and reusable prompt formatting.

2. **Interact with the LLM using the template**  

   This function handles communication with the language model using the rendered prompt. It is currently tailored for the Groq API, though the structure may vary if you use a different API.

3. **Save the LLM response to a JSON file**  

   To maintain a record of the process, we save the model's response along with metadata, including the model used, the prompt sent, and the annotated output (in XML format with entity annotations).


In [261]:
def load_and_render_prompt(template_path: str, text_to_annotate: str) -> str:
    """
    Load a Jinja2 template from a file and render it with the provided text to annotate.

    The text is exposed to the template under the variable name `text_to_annotate`.

    Args:
        template_path (str): Path to the template file.
        text_to_annotate (str): Text to be annotated.

    Returns:
        str: Rendered prompt string.
    """
    # Read the template as UTF-8 explicitly so that non-ASCII characters in a
    # prompt are decoded identically regardless of the platform's default encoding.
    with open(template_path, "r", encoding="utf-8") as f:
        template_content = f.read()
    template = jinja2.Template(template_content)
    return template.render(text_to_annotate=text_to_annotate)


def chat_with_template(prompt:str, template_path: str, model:str, text_to_annotate:str) -> str:
    """
    Render a prompt template and send it to the Groq chat completion API.

    Args:
        prompt (str): Not used to build the request: the message content is always
            rendered from `template_path`. The parameter is kept so that existing
            callers keep working.
        template_path (str): Path to the template file.
        model (str): Model to use for the chat.
        text_to_annotate (str): Text to be annotated.
    Returns:
        str: Message content of the first choice returned by the chat completion.
    """
    # Keep the rendered text in its own variable instead of silently overwriting
    # the `prompt` argument.
    rendered_prompt = load_and_render_prompt(template_path, text_to_annotate)

    # The whole rendered prompt is sent as a single user message.
    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": rendered_prompt}],
        model=model,
    )
    return chat_completion.choices[0].message.content


def save_response_as_json(response_text:str, output_path:str) -> None:
    """
    Serialize a response to a JSON file.

    The value is handed to `json.dump` as-is, so any JSON-serializable object is
    accepted (the annotation loop passes a dict, not a plain string).

    Args:
        response_text (str): The response (or record containing it) to save.
        output_path (str): Path to the output JSON file. Its parent folder must exist.
    """
    # ensure_ascii=False keeps non-ASCII characters readable in the file,
    # which is why the file is written explicitly as UTF-8.
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(response_text, f, ensure_ascii=False, indent=2)

Before we start running annotations, we need to set up a directory structure to organize the outputs based on prompt types and models used.

The following code will:
- Create a root folder to store all LLM annotations.


- For each prompt template in `LIST_PROMPTS`, create a subfolder.


- Within each prompt folder, create additional subfolders for each model in `LIST_MODELS`.

In [262]:
# Root directory for the LLM annotations
ouput_llm_annotation_folder = "../output_llm_annotations/"
# exist_ok=True makes the cell safe to re-run and avoids the check-then-create race
os.makedirs(ouput_llm_annotation_folder, exist_ok=True)

# Create a folder for each prompt type,
# then within that, a folder for each model
for prompt in LIST_PROMPTS:
    prompt_name = os.path.basename(prompt)
    output_prompt_folder = os.path.join(ouput_llm_annotation_folder, prompt_name)
    os.makedirs(output_prompt_folder, exist_ok=True)

    for model in LIST_MODELS:
        # A model identifier containing "/" becomes a nested folder;
        # os.makedirs creates every intermediate level.
        output_model_folder = os.path.join(output_prompt_folder, model)
        os.makedirs(output_model_folder, exist_ok=True)

## **Run LLM annotations**
---

We will now test our LLM annotation pipeline on a subset of input texts. Specifically, we will:

- Select the first 10 input files from the annotations folder.

    - For each file, we will apply:

        - Each prompt template in `LIST_PROMPTS`

        - Each language model in `LIST_MODELS`

    - Save the LLM's annotated response as a JSON file in its designated directory: `../output_llm_annotations/{prompt_name}/{model}/{filename}`



We can modify the number of files that are annotated. To give an idea, for two models, three types of prompts, and 10 texts to annotate, this takes **~7 min**.

In [263]:
# Maximum number of input texts taken from the annotation folder (raise it to annotate more)
MAX_TEXTS = 10
number_texts = 0

# os.listdir() returns entries in arbitrary order; sorting makes the selected
# subset of files identical on every run and on every machine.
for filename in sorted(os.listdir(ANNOTATIONS_FOLDER)): # Loop through the files in the annotations folder
    if number_texts >= MAX_TEXTS:
        break

    # Only JSON files whose name contains exactly one underscore are annotated
    if not (filename.endswith(".json") and filename.count("_") == 1):
        continue

    number_texts += 1

    print(f"\nProcessing file {number_texts}: {filename} ==============")
    input_text, _ = process_json_file(os.path.join(ANNOTATIONS_FOLDER, filename))

    for prompt in LIST_PROMPTS: # Testing each type of prompt
        print(f"\nFile {number_texts} - Testing prompt: {prompt} -------\n")

        prompt_name = os.path.basename(prompt)
        output_prompt_folder = os.path.join(ouput_llm_annotation_folder, prompt_name)

        for model in LIST_MODELS: # Testing each model
            print(f"File {number_texts} - Testing model: {model}")

            output_model_folder = os.path.join(output_prompt_folder, model)

            # One API call per (file, prompt, model) combination
            response = chat_with_template(
                prompt=prompt,
                template_path=os.path.join(PROMPT_PATH, f"{prompt}.txt"),
                model=model,
                text_to_annotate=input_text
            )

            # Save the response, together with the model and the input text, as a JSON file
            output_path_for_json = os.path.join(output_model_folder, filename)
            data = {
                "model": model,
                "text_to_annotate": input_text,
                "response": response
            }
            save_response_as_json(data, output_path_for_json)



File 1 - Testing prompt: zero_shot_prompt_template -------

File 1 - Testing model: llama-3.3-70b-versatile
File 1 - Testing model: meta-llama/llama-4-maverick-17b-128e-instruct
File 1 - Testing model: deepseek-r1-distill-llama-70b

File 1 - Testing prompt: one_shot_prompt_template -------

File 1 - Testing model: llama-3.3-70b-versatile
File 1 - Testing model: meta-llama/llama-4-maverick-17b-128e-instruct
File 1 - Testing model: deepseek-r1-distill-llama-70b

File 1 - Testing prompt: few_shot_prompt_template -------

File 1 - Testing model: llama-3.3-70b-versatile
File 1 - Testing model: meta-llama/llama-4-maverick-17b-128e-instruct
File 1 - Testing model: deepseek-r1-distill-llama-70b


File 2 - Testing prompt: zero_shot_prompt_template -------

File 2 - Testing model: llama-3.3-70b-versatile
File 2 - Testing model: meta-llama/llama-4-maverick-17b-128e-instruct
File 2 - Testing model: deepseek-r1-distill-llama-70b

File 2 - Testing prompt: one_shot_prompt_template -------

File 2 -

## **Response quality control**
---

Before analyzing the LLM responses, we implement a few helper functions to perform a basic quality check. This check helps ensure that the outputs meet a minimum standard before moving forward with evaluation or further processing.

For now, our quality control focuses on **verifying that the LLM response, once its annotation tags are removed, is identical to the original input text**. This simple check helps catch hallucinations or unrelated outputs from the model. Note that the comparison is **case-insensitive**.

While this is a strict baseline, it provides a quick filter for obviously flawed responses.

In the future, we plan to implement more nuanced validation criteria, such as  ensuring that entities annotated by the LLM actually appear in the original input text.

This will allow us to be more flexible while still maintaining meaningful quality standards.


To ensure the integrity of the LLM annotations, we need to verify that the output text closely matches the input text we provided.

1. **Strip annotation tags**


Since the LLM response includes XML-like tags (e.g., `<TAG LABEL>` and `</TAG LABEL>`), we first remove these tags to isolate the raw text. This allows for a fair comparison between the original input and the annotated output.

2. **Compare input and output texts**


Once tags are stripped, we compare the cleaned LLM output to the original input text. This helps us confirm that the model has not introduced hallucinated content or omitted any part of the input. Only the responses that **pass this check** are retained for further analysis or evaluation.

In [264]:
# Remove the annotation tags so that only the raw text remains
def strip_tags(text:str, tags=TAGS) -> str:
    """
    Delete every opening and closing tag listed in `tags` from `text`.

    Tags are removed one label at a time, in the order given, and the
    resulting text is returned with leading/trailing whitespace trimmed.
    """
    cleaned = text
    for label in tags:
        # Matches both <LABEL> and </LABEL>
        tag_pattern = re.compile(f"</?{re.escape(label)}>")
        cleaned = tag_pattern.sub("", cleaned)
    return cleaned.strip()

# Function to compare the annotated text to the original input
def compare_annotated_to_original(original: str, annotated: str) -> bool:
    """
    Compare the tag-stripped annotated text to the original input, ignoring case.

    Args:
        original (str): Text that was given to the LLM.
        annotated (str): LLM response containing annotation tags.

    Returns:
        bool: True if both texts are identical after removing the tags, trimming
        surrounding whitespace and lowercasing; False otherwise, including when
        either argument is None.
    """
    # load_annotation_result() returns (None, None) for unreadable files; such a
    # pair cannot be a conserved text, so report a mismatch instead of crashing.
    if original is None or annotated is None:
        return False

    stripped = strip_tags(annotated).strip().lower()
    original = original.strip().lower()

    return stripped == original

def load_annotation_result(file_path:str) -> tuple:
    """
    Read a saved LLM annotation and return the input text and the model response.

    Args:
        file_path (str): Path to the JSON file.
    Returns:
        tuple: (original text, annotated output), or (None, None) when the file is
        missing, is not valid JSON, or lacks one of the two fields. In that case
        the error is printed.
    """
    try:
        with open(file_path, "r", encoding="utf-8") as handle:
            payload = json.load(handle)

        loaded = (payload.get("text_to_annotate"), payload.get("response"))
        # A missing key and an explicit null are treated the same way
        if any(value is None for value in loaded):
            raise ValueError("Missing required fields in JSON: 'text_to_annotate' or 'response'.")
    except (json.JSONDecodeError, FileNotFoundError, ValueError) as err:
        print(f"Error loading file {file_path}:", err)
        return None, None

    return loaded


>
>Later on: Qualifying rejected annotations
>
>Although not implemented yet, we aim to better understand and categorize the annotations that fail quality control. These might include:
>
>- **Text changed, but entities preserved**  
>
>- **Entities altered or incorrectly annotated**  
>
>- **Significant rewrites of the input (both entities and text)**
>

This section evaluates whether the annotated outputs generated by the LLMs preserve the original input text (after stripping tags). We track and record whether each annotated file was conserved or modified.
At the end of each prompt's processing loop, we print:
- Total number of texts processed

- Number of conserved outputs

- Number of modified outputs


This is based on **the type of prompt** (zero-shot, one-shot, and few-shot)

In [265]:
# Paths of the LLM outputs, and of the matching ground-truth files, split by
# quality-control outcome. They are reset here so the cell can be re-run safely.
list_of_conserved_llm_texts = []
list_of_conserved_filenames = []
list_of_modified_filenames = []
list_of_modified_llm_texts = []

for prompt in LIST_PROMPTS: # Testing each type of prompt
    print(f"\nPrompt type: {prompt} ==============\n")

    prompt_name = os.path.basename(prompt)
    output_prompt_folder = os.path.join(ouput_llm_annotation_folder, prompt_name)

    # Counters for the current prompt (summed over all models)
    prompt_total_texts = 0
    prompt_conserved_texts = 0
    prompt_modified_texts = 0

    for model in LIST_MODELS: # Testing each model
        print(f"Model ID: {model}\n")

        output_model_folder = os.path.join(output_prompt_folder, model)

        for filename in os.listdir(output_model_folder): # Loop through the LLM output files
            print(f"Processing file : {filename} ---------")
            llm_output_path = os.path.join(output_model_folder, filename)

            # Load the annotation result
            original_text, annotated_output = load_annotation_result(llm_output_path)
            if original_text is None or annotated_output is None:
                # The loader already printed the reason; an unreadable file is
                # neither conserved nor modified, so it is left out of the counts.
                continue

            # Compare the original text to the annotated output
            result = compare_annotated_to_original(original_text, annotated_output)
            print("Original text conserved? -", result, '\n')

            # Update the prompt-specific counters and remember where the file went.
            # (Only counters initialised in this cell are used, so the cell also
            # runs on a fresh kernel.)
            prompt_total_texts += 1
            if result:
                prompt_conserved_texts += 1
                list_of_conserved_llm_texts.append(llm_output_path)
                list_of_conserved_filenames.append(os.path.join(ANNOTATIONS_FOLDER, filename))
            else:
                prompt_modified_texts += 1
                list_of_modified_llm_texts.append(llm_output_path)
                list_of_modified_filenames.append(os.path.join(ANNOTATIONS_FOLDER, filename))

    # Print the results for each prompt
    print("\nResults for prompt:", prompt)
    print(f"Total texts processed: {prompt_total_texts}")
    print(f"Conserved texts: {prompt_conserved_texts}")
    print(f"Modified texts: {prompt_modified_texts}")



Model ID: llama-3.3-70b-versatile

Processing file : figshare_22213635.json ---------
Original text conserved? - True 

Processing file : figshare_4757161.json ---------
Original text conserved? - False 

Processing file : figshare_21263177.json ---------
Original text conserved? - False 

Processing file : zenodo_6582985.json ---------
Original text conserved? - False 

Processing file : zenodo_6478270.json ---------
Original text conserved? - False 

Processing file : figshare_7783568.json ---------
Original text conserved? - True 

Processing file : figshare_20300547.json ---------
Original text conserved? - False 

Processing file : zenodo_4805388.json ---------
Original text conserved? - False 

Processing file : zenodo_51754.json ---------
Original text conserved? - False 

Processing file : figshare_20009556.json ---------
Original text conserved? - False 

Model ID: meta-llama/llama-4-maverick-17b-128e-instruct

Processing file : figshare_22213635.json ---------
Original text

This time, we aggregate quality control results per **model**, allowing us to assess how well each model preserves the original input text across all prompt types.

In [266]:
for model in LIST_MODELS: # Testing each model
    print(f"\nModel ID: {model}")

    # Counters for the current model (summed over all prompt types)
    model_total_texts = 0
    model_conserved_total_texts = 0
    model_modified_total_texts = 0

    for prompt in LIST_PROMPTS:
        print(f"\nPrompt type: {prompt} ==============")
        prompt_name = os.path.basename(prompt)
        output_prompt_folder = os.path.join(ouput_llm_annotation_folder, prompt_name)
        output_model_folder = os.path.join(output_prompt_folder, model)

        # Loop through the LLM output files in the folder
        for filename in os.listdir(output_model_folder):
            print(f"\nProcessing file : {filename} ---------")
            # Load the annotation result
            original_text, annotated_output = load_annotation_result(os.path.join(output_model_folder, filename))
            if original_text is None or annotated_output is None:
                # The loader already printed the reason; an unreadable file is
                # neither conserved nor modified, so it is left out of the counts.
                continue

            # Compare the original text to the annotated output
            result = compare_annotated_to_original(original_text, annotated_output)
            print("Original text conserved? -", result)
            if result:
                model_conserved_total_texts += 1
            else:
                model_modified_total_texts += 1
            model_total_texts += 1

    # The summary is aggregated per model, so label it accordingly
    print("\nResults for model:", model)
    print(f"Total texts processed: {model_total_texts}")
    print(f"Conserved texts: {model_conserved_total_texts}")
    print(f"Modified texts: {model_modified_total_texts}")



Model ID: llama-3.3-70b-versatile


Processing file : figshare_22213635.json ---------
Original text conserved? - True

Processing file : figshare_4757161.json ---------
Original text conserved? - False

Processing file : figshare_21263177.json ---------
Original text conserved? - False

Processing file : zenodo_6582985.json ---------
Original text conserved? - False

Processing file : zenodo_6478270.json ---------
Original text conserved? - False

Processing file : figshare_7783568.json ---------
Original text conserved? - True

Processing file : figshare_20300547.json ---------
Original text conserved? - False

Processing file : zenodo_4805388.json ---------
Original text conserved? - False

Processing file : zenodo_51754.json ---------
Original text conserved? - False

Processing file : figshare_20009556.json ---------
Original text conserved? - False


Processing file : figshare_22213635.json ---------
Original text conserved? - False

Processing file : figshare_4757161.json -----

## **LLM annotations scoring**
---

To assess the quality of entity annotations produced by different LLMs, we implement a set of evaluation metrics that allow both quantitative and qualitative analysis. We aim to measure how well each model performs in identifying and labeling entities.

However, before we can properly assess the quality of the annotations, we need to extract the entities and store them in a standard structure.
Both the ground-truth entities and the llm-annotated entities will be in the following structure:

```json
{
  "MOL": ["arylamide", "hDM2", "p53", "Nutlin-2", "benzodiazepinedione", "p53 helix"],
  "SOFTNAME": ["GAFF", "AutoDock"],
  "SOFTVERS": [],
  "STIME": ["20 ns"],
  "TEMP": [],
  "FFM": []
}
```

For this, we need helper functions that will:

- convert the current **ground-truth annotation** format to the one we want


Current ground-truth annotation format:
```json
{
	"classes": ["TEMP", "SOFT", "STIME", "MOL", "FFM"],
	"annotations": [[
    	"An in silico approach to determine inter-subunit affinities in human septin complexes.",
    	{"entities": [[69, 75, "MOL"], [90, 97, "MOL"], [1255, 1260, "MOL"], [1368, 1374, "MOL"]]}
	]]
}
```

- convert the **LLM-output annotation** format to the one we want

LLM-output annotation format:
```json
{
  "model": "gemma2-9b-it",
  "text_to_annotate": "Extending the Stochastic Titration CpHMD to CHARMM36m.",
  "response": "Extending the Stochastic Titration CpHMD to <FFM>CHARMM36m</FFM>."
}
```

In [267]:
def process_llm_json_file(json_file: str) -> tuple:
    """
    Read a saved LLM annotation file and return its three fields.

    Args:
        json_file (str): Path to the LLM output JSON file.

    Returns:
        tuple: (text_to_annotate, response, model) as stored in the file.

    Raises:
        KeyError: If one of the three fields is missing from the file.
    """
    # The files are written as UTF-8 with ensure_ascii=False, so they must be
    # read back as UTF-8 rather than with the platform's default encoding.
    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Extract the input text, response, and model
    text_to_annotate = data["text_to_annotate"]
    response = data["response"]
    model = data["model"]

    return text_to_annotate, response, model

def extract_entities_from_llm_text(text: str) -> dict:
    """
    Collect the tagged entities of an LLM response, grouped by tag.

    Entities are expected to be wrapped in matching tags, e.g.:
    "Extending the Stochastic Titration CpHMD to <FFM>CHARMM36m</FFM> using <SOFTNAME>Gromacs</SOFTNAME>"

    Returns a dictionary with one key per entry of TAGS; each value is the list of
    entity strings found for that tag (whitespace-trimmed, in order of appearance).
    Tags that are not listed in TAGS are skipped.
    """
    # Every expected tag is present in the output, even when nothing is found for it.
    entities_by_tag = {tag_name: [] for tag_name in TAGS}

    # <TAG>content</TAG>; the backreference \1 forces the closing tag to match the opening one.
    tagged_span = re.compile(r"<([A-Z]+)>(.*?)</\1>")

    for match in tagged_span.finditer(text):
        tag_name, content = match.groups()
        if tag_name not in entities_by_tag:
            continue
        entities_by_tag[tag_name].append(content.strip())

    return entities_by_tag

def extract_entities_from_annotation(text: str, entities: list) -> dict:
    """
    Extract entities from the given text based on a direct list of annotation triples.

    The entities input should be a list of lists formatted as:
    [
        [start_index, end_index, "ENTITY_TYPE"],
        ...
    ]

    The function extracts the substring text[start_index:end_index] for each triple
    and groups the results by entity type according to TAGS.
    The ground-truth type 'SOFT' is renamed to 'SOFTNAME'.
    If an entity type is not in TAGS, it is ignored.
    If no entities are found for a type, its output list remains empty.

    The function returns a dictionary with keys corresponding to the desired entity types
    and values as lists with the extracted entity content.
    """
    # Initialize the output dictionary with empty lists for each desired key.
    result = {key: [] for key in TAGS}

    # Iterate over each entity annotation.
    for start, end, entity_type in entities:
        if entity_type == 'SOFT':
            entity_type = 'SOFTNAME'

        # Skip types outside TAGS, as documented, instead of raising KeyError.
        if entity_type not in result:
            continue

        result[entity_type].append(text[start:end])

    return result


In [268]:
# Example usage: entities of one ground-truth file
filename = "figshare_4757161.json"
input_text, ground_truth_entities = process_json_file(os.path.join(ANNOTATIONS_FOLDER, filename))

extracted = extract_entities_from_annotation(input_text, ground_truth_entities)
print("Ground-truth entities:", extracted)

# Example usage: entities of the first LLM output that passed quality control.
# The list can be empty when no output was conserved, so guard the [0] lookup.
if list_of_conserved_llm_texts:
    one_json_example = list_of_conserved_llm_texts[0]
    print(one_json_example)

    _, response, _ = process_llm_json_file(one_json_example)
    llm_extracted = extract_entities_from_llm_text(response)
    print("LLM extracted entities:", llm_extracted)
else:
    print("No conserved LLM output available to use as an example.")

Ground-truth entities: {'MOL': ['Enzyme', 'Polystyrene', 'styrene', 'STY', 'CAT', 'SPA', 'methyl', 'Styrene', '2-phenylpropane', 'methyl', 'styrene', 'methyl', 'STY', 'styrene', 'styrene', 'methyl', 'styrene', 'methyl', 'CAT', 'styrene', 'SPA', 'styrene', 'alkyl', 'p-divinylbenzene', 'p-divinylbenzene', 'styrene', 'styrene', 'waters', 'styrene', 'CAT', 'SPA'], 'SOFTNAME': ['Amber16', 'Antechamber', 'Gaussian09', 'AmberTools', 'SHAKE', 'PMEMD', 'Amber16', 'AmberTools', 'LEaP', 'Amber'], 'SOFTVERS': [], 'STIME': ['100 ns'], 'TEMP': ['300 K', '300 K', '300 K', '300 K'], 'FFM': ['General Amber Force Field', 'GAFF', 'GAFF', 'TIP3P']}
../output_llm_annotations/zero_shot_prompt_template/llama-3.3-70b-versatile/figshare_22213635.json
LLM extracted entities: {'MOL': ['NH3+', 'H2', 'NH4+', 'H', 'NH3+', 'H2', 'NH4+', 'H', 'NH3+', 'H2'], 'SOFTNAME': [], 'SOFTVERS': [], 'STIME': [], 'TEMP': ['T 130 K'], 'FFM': []}


Now onto the scoring. **Evaluation logic:**



1. **Exact match scoring**: Entity is correct if string and type match exactly.


2. **Confidence score**: Fraction of LLMs that agreed on the same entity. (!!! tricky because not all the llms will conserve the text) - ***Not added in yet***


3. **Detection ratio**: Correct entities found vs. total ground truth.


4. **False positives**: Entities predicted but not in ground truth.


5. **False negatives**: Ground truth entities missed by LLM.


6. **Per-type breakdown**: Scores computed by entity type.

In [272]:
# Exact match score over all entity types
def exact_match_score(ground_truth: Dict[str, List[str]], predicted: Dict[str, List[str]]) -> Tuple[int, int, float]:
    """
    Count the ground-truth entities that were predicted with the same string and type.

    Every ground-truth occurrence is counted separately, so an entity listed twice
    in the ground truth and predicted at least once contributes two matches.

    Parameters:
        ground_truth (dict): Ground truth annotations.
        predicted (dict): Predicted annotations.

    Returns:
        tuple: (number of exact matches, total ground truth entities, score ratio);
        the ratio is 0 when there are no ground-truth entities.
    """
    hits_per_type = []
    total = 0
    for label, expected in ground_truth.items():
        found = set(predicted.get(label, []))
        total += len(expected)
        # Ground-truth entities of this type that also appear in the predictions
        hits_per_type.append(sum(item in found for item in expected))

    matched = sum(hits_per_type)
    if total > 0:
        return matched, total, matched / total
    return matched, total, 0


def detection_ratio(ground_truth: Dict[str, List[str]], predicted: Dict[str, List[str]]) -> Dict[str, float]:
    """
    Fraction of ground-truth entities found by the prediction, for each entity type.

    Parameters:
        ground_truth (dict): Ground truth annotations.
        predicted (dict): Predicted annotations.

    Returns:
        dict: Mapping from each ground-truth entity type to its detection ratio
        (0 to 1), or None when the ground truth has no entity of that type.
    """
    ratios = {}
    for label, expected in ground_truth.items():
        found = set(predicted.get(label, []))
        if not expected:
            # No ground truth for this type: the ratio is undefined
            ratios[label] = None
            continue
        hits = len([item for item in expected if item in found])
        ratios[label] = hits / len(expected)
    return ratios


def false_positives(ground_truth: Dict[str, List[str]], predicted: Dict[str, List[str]]) -> Dict[str, List[str]]:
    """
    List, for each predicted entity type, the predictions absent from the ground truth.

    A predicted entity is a false positive when its string does not appear in the
    ground truth for the same type. Repeated predictions are kept as repeated entries.

    Parameters:
        ground_truth (dict): Ground truth annotations.
        predicted (dict): Predicted annotations.

    Returns:
        dict: Mapping from entity type to a list of false positive entities.
    """
    extras = {}
    for label in predicted:
        known = set(ground_truth.get(label, []))
        extras[label] = list(filter(lambda item: item not in known, predicted[label]))
    return extras


def false_negatives(ground_truth: Dict[str, List[str]], predicted: Dict[str, List[str]]) -> Dict[str, List[str]]:
    """
    List, for each ground-truth entity type, the entities the prediction missed.

    A ground-truth entity is a false negative when its string does not appear in
    the predictions for the same type. Repeated ground-truth entries are kept as
    repeated entries.

    Parameters:
        ground_truth (dict): Ground truth annotations.
        predicted (dict): Predicted annotations.

    Returns:
        dict: Mapping from entity type to a list of false negative entities.
    """
    missed = {}
    for label in ground_truth:
        found = set(predicted.get(label, []))
        missed[label] = list(filter(lambda item: item not in found, ground_truth[label]))
    return missed


def per_type_breakdown(ground_truth: Dict[str, List[str]], predicted: Dict[str, List[str]]) -> Dict[str, Dict[str, object]]:
    """
    Provides a detailed breakdown per entity type.

    For each entity type, returns a dict with:
      - 'exact_matches': number of ground-truth entities found in the predictions,
      - 'total_gt': total number of ground truth entities,
      - 'detection_ratio': fraction of ground truth detected (None if 'total_gt' is 0),
      - 'false_positives': number of predicted entities absent from the ground truth,
      - 'false_negatives': number of ground-truth entities absent from the predictions.

    Parameters:
        ground_truth (dict): Ground truth annotations.
        predicted (dict): Predicted annotations.

    Returns:
        dict: Breakdown per entity type. Types appear in the order of the
        ground-truth keys, followed by the types that were only predicted.
    """
    # dict.fromkeys de-duplicates while keeping insertion order; iterating over a
    # set of strings instead would change the order from one run to the next.
    entity_types = dict.fromkeys([*ground_truth, *predicted])

    breakdown = {}
    for entity_type in entity_types:
        gt_entities = ground_truth.get(entity_type, [])
        pred_entities = predicted.get(entity_type, [])
        gt_set = set(gt_entities)
        pred_set = set(pred_entities)

        exact_match_count = sum(1 for e in gt_entities if e in pred_set)
        total_gt = len(gt_entities)
        detection = exact_match_count / total_gt if total_gt > 0 else None

        breakdown[entity_type] = {
            'exact_matches': exact_match_count,
            'total_gt': total_gt,
            'detection_ratio': detection,
            'false_positives': sum(1 for e in pred_entities if e not in gt_set),
            'false_negatives': sum(1 for e in gt_entities if e not in pred_set)
        }

    return breakdown


In [270]:
# for i in range(len(list_of_conserved_llm_texts)):
#     print("LLM text path:", list_of_conserved_llm_texts[i])
#     print("Ground truth text path:", list_of_conserved_filenames[i],"\n")

In [273]:
# Score every LLM output that passed quality control against its ground-truth file.
# The two lists are filled in parallel, so they are walked pairwise.
for llm_filename, gt_filename in zip(list_of_conserved_llm_texts, list_of_conserved_filenames):

    # Ground truth: convert the character spans into {entity type: [strings]}
    input_text, ground_truth_entities = process_json_file(gt_filename)
    gt_extracted = extract_entities_from_annotation(input_text, ground_truth_entities)

    # LLM output: pull the tagged entities out of the response
    _, response, _ = process_llm_json_file(llm_filename)
    llm_extracted = extract_entities_from_llm_text(response)

    # Metrics ----------------------------------------------------------------
    matched, total, score_ratio = exact_match_score(gt_extracted, llm_extracted)
    fps = false_positives(gt_extracted, llm_extracted)
    fns = false_negatives(gt_extracted, llm_extracted)
    detect_ratio = detection_ratio(gt_extracted, llm_extracted)
    breakdown = per_type_breakdown(gt_extracted, llm_extracted)

    # Report -----------------------------------------------------------------
    print(llm_filename, "\n")
    print(f"★ Exact match score: {matched}/{total} ({score_ratio:.2f})")

    print("\n★ False positives (hallucination ?):")
    for etype in fps:
        print("  {}: {}".format(etype, fps[etype]))

    print("\n★ False negatives (missed ?):")
    for etype in fns:
        print("  {}: {}".format(etype, fns[etype]))

    print("\n★ Detection ratio per type (# of correct entities found by LLM ÷ # of entities in the ground truth):")
    for etype in detect_ratio:
        ratio = detect_ratio[etype]
        print(f"  {etype}: {ratio}")

    print("\n★ Per-type breakdown:")
    for etype in breakdown:
        print("  {}: {}".format(etype, breakdown[etype]))

    # Visual separator between two files
    print("\n","="*200, "\n")


../output_llm_annotations/zero_shot_prompt_template/llama-3.3-70b-versatile/figshare_22213635.json 

★ Exact match score: 9/10 (0.90)

★ False positives (hallucination ?):
  MOL: []
  SOFTNAME: []
  SOFTVERS: []
  STIME: []
  TEMP: ['T 130 K']
  FFM: []

★ False negatives (missed ?):
  MOL: []
  SOFTNAME: []
  SOFTVERS: []
  STIME: []
  TEMP: ['130 K']
  FFM: []

★ Detection ratio per type (# of correct entities found by LLM ÷ # of entities in the ground truth):
  MOL: 1.0
  SOFTNAME: None
  SOFTVERS: None
  STIME: None
  TEMP: 0.0
  FFM: None

★ Per-type breakdown:
  STIME: {'exact_matches': 0, 'total_gt': 0, 'detection_ratio': None, 'false_positives': 0, 'false_negatives': 0}
  TEMP: {'exact_matches': 0, 'total_gt': 1, 'detection_ratio': 0.0, 'false_positives': 1, 'false_negatives': 1}
  FFM: {'exact_matches': 0, 'total_gt': 0, 'detection_ratio': None, 'false_positives': 0, 'false_negatives': 0}
  SOFTVERS: {'exact_matches': 0, 'total_gt': 0, 'detection_ratio': None, 'false_positives

# **WE NEED TO SAVE CERTAIN STATS IN A CSV FILE SO THAT WE CAN PLOT SOME GRAPHS OUT OF THEM !!!**

- QUALITY CONTROL IS TOO STRICT