# Knowledge Graph Generation: A/B Prompt Comparison

This notebook conducts an experiment to compare the performance of two different system prompts for generating knowledge graphs from text using a Large Language Model (LLM). It performs the following steps:

1.  **Loads** the WikiGraphs dataset.
2.  **Selects** a sample text and its corresponding ground truth graph.
3.  **Generates** a knowledge graph using a **default system prompt**.
4.  **Generates** a second knowledge graph using a **user-supplied system prompt** for comparison.
5.  **Compares** the two generated graphs quantitatively (number of unique nodes and edges) and by how consolidated their entity sets are, as a proxy for coreference resolution.
6.  **Evaluates** both prompts against the ground truth graph using triplet-level Precision, Recall, and F1-Score for relation extraction.

## 1. Setup and Dependencies

In [2]:
# Ensure all required libraries are installed
# !pip install pandas ollama chromadb langchain tqdm wikigraphs-data

In [3]:
import os
import json
import pandas as pd
import ollama
from itertools import chain
from tqdm.auto import tqdm
from pathlib import Path
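from langchain.text_splitter import RecursiveCharacterTextSplitter
from IPython.display import display  # built into Jupyter; imported explicitly so the code also runs as a script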

# Assuming 'wikigraphs' module is in the same directory or Python path
from wikigraphs.data import paired_dataset

pd.set_option('display.max_colwidth', 300)

## 2. Configuration
Modify the variables in this section to customize the experiment.

In [4]:
# Select which sample from the dataset to run the experiment on (e.g., 0 for the first one)
SAMPLE_INDEX = 0

# The LLM model to use with Ollama
LLM_MODEL = "llama3.1:8b"

# Path to the WikiGraphs processed data
WIKIGRAPHS_DATA_DIR = "data/wikigraphs/"

## 3. System Prompt Definitions
Here we define the two system prompts that will be compared. The `USER_SUPPLIED_SYS_PROMPT` can be modified to test different prompting strategies.

In [12]:
# The default system prompt from the original thesis notebook.
DEFAULT_SYS_PROMPT = (
    "You are a knowledge graph maker who extracts terms and their relations from a given context. "
    "You are provided with a context chunk (delimited by ```). Your task is to extract the ontology "
    "of terms mentioned in the given context. These terms should represent the key concepts as per the context. \n"
    "Thought 1: While traversing through each sentence, think about the key terms mentioned in it.\n"
    "\tTerms may include object, entity, location, organization, person, \n"
    "\tcondition, acronym, documents, service, concept, etc.\n"
    "\tTerms should be as atomistic as possible\n\n"
    "Thought 2: Think about how these terms can have one on one relation with other terms.\n"
    "\tTerms that are mentioned in the same sentence or the same paragraph are typically related to each other.\n"
    "\tTerms can be related to many other terms\n\n"
    "Thought 3: Find out the relation between each such related pair of terms. \n\n"
    "Format your output as a list of json. Each element of the list contains a pair of terms "
    "and the relation between them, like the following: \n"
    "[\n"
    "   {\n"
    '       "node_1": "A concept from extracted ontology",\n'
    '       "node_2": "A related concept from extracted ontology",\n'
    '       "edge": "relationship between the two concepts, node_1 and node_2 in one or two sentences"\n'
    "   }, {...}\n"
    "]\n"
    "Do not add any other comment before or after the json. Respond ONLY with a well formed json that can be directly read by a program."
)

# <<< EDIT THIS PROMPT FOR YOUR EXPERIMENT >>>
# This prompt extends the default with explicit instructions for pronoun resolution and consistent entity naming.
USER_SUPPLIED_SYS_PROMPT = (
    "You are a knowledge graph maker who extracts terms and their relations from a given context. "
    "You are provided with a context chunk (delimited by ```). Your task is to extract the ontology "
    "of terms mentioned in the given context. These terms should represent the key concepts as per the context. \n"
    "Thought 1: While traversing through each sentence, think about the key terms mentioned in it.\n"
    "\tTerms may include object, entity, location, organization, person, \n"
    "\tcondition, acronym, documents, service, concept, etc.\n"
    "\tTerms should be as atomistic as possible\n\n"
    "Thought 2: Think about how these terms can have one on one relation with other terms.\n"
    "\tTerms that are mentioned in the same sentence or the same paragraph are typically related to each other.\n"
    "\tTerms can be related to many other terms\n\n"
    "Thought 3: Find out the relation between each such related pair of terms.\n"
    "Thought 4: When you extract triples, you must resolve all pronouns (like 'he', 'she', 'it', 'they') to the "
    "specific named entity they refer to based on the context of the passage. Do not use pronouns in the subject "
    "or object of your output triples.\n"
    "Thought 5: If an entity is mentioned multiple times under different names, use the most uniquely "
    "distinguishable of those names when writing that entity into a triple, and use that same name consistently "
    "for every mention of the entity.\n\n"
    "Format your output as a list of json. Each element of the list contains a pair of terms "
    "and the relation between them, like the following: \n"
    "[\n"
    "   {\n"
    '       "node_1": "A concept from extracted ontology",\n'
    '       "node_2": "A related concept from extracted ontology",\n'
    '       "edge": "relationship between the two concepts, node_1 and node_2 in one or two sentences"\n'
    "   }, {...}\n"
    "]\n"
    "Do not add any other comment before or after the json. Respond ONLY with a well formed json that can be directly read by a program."
)

# USER_SUPPLIED_SYS_PROMPT = (
#     "You are an expert knowledge graph extractor. Your task is to analyze the provided text and "
#     "extract relationships as (Subject, Predicate, Object) triplets. "
#     "CRITICAL INSTRUCTIONS:\n"
#     "1. **Consistent Naming**: If an entity is mentioned multiple times (e.g., 'Jacqueline Fernandez', 'Fernandez'), "
#     "always use the most complete name ('Jacqueline Fernandez') for the node to ensure consistency.\n"
#     "2. **Atomicity**: Nodes should represent single, atomic concepts.\n"
#     "3. **Output Format**: Your final output must be ONLY a valid JSON list of dictionaries. Each dictionary must "
#     "contain exactly three keys: 'node_1' (Subject), 'node_2' (Object), and 'edge' (Predicate).\n"
#     "Example format:\n"
#     "[\n"
#     '   {"node_1": "Jacqueline Fernandez", "node_2": "Sri Lankan actress", "edge": "is a"},\n'
#     '   {"node_1": "Jacqueline Fernandez", "node_2": "Manama, Bahrain", "edge": "was born in"}\n'
#     "]\n"
#     "Do not include any explanations, comments, or text outside of the JSON list."
# )
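Both prompts are assembled from adjacent string literals inside parentheses. Python joins such literals at compile time with no separator, so a fragment that does not end in a space or newline runs straight into the next one. A minimal illustration:

```python
# Adjacent string literals are concatenated with no separator.
missing_space = (
    "Each element of the list contains a pair of terms"
    "and the relation between them"
)
with_space = (
    "Each element of the list contains a pair of terms "
    "and the relation between them"
)

print(missing_space)  # Each element of the list contains a pair of termsand the relation between them
print(with_space)     # Each element of the list contains a pair of terms and the relation between them
```

Printing a prompt once (for example `print(USER_SUPPLIED_SYS_PROMPT)`) is a quick way to spot such joins before running the experiment.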

## 4. Helper Functions for Data Loading and Graph Generation

In [6]:
def load_wikigraphs_data(data_root, subset='train', version='max256'):
    """Loads the parsed WikiGraphs dataset."""
    print("Loading WikiGraphs dataset...")
    paired_dataset.DATA_ROOT = data_root
    dataset = paired_dataset.ParsedDataset(
        subset=subset,
        shuffle_data=False,
        data_dir=None,
        version=version
    )
    parsed_pairs = list(dataset)
    print(f"Loaded {len(parsed_pairs)} pairs from the dataset.")
    return parsed_pairs

def get_ground_truth_graph(pair):
    """Converts a WikiGraphs pair object into a pandas DataFrame of triplets."""
    g = pair.graph
    df = pd.DataFrame(g.edges(), columns=["src", "tgt", "edge"])
    df["subject"] = df["src"].apply(lambda node_id: g.nodes()[node_id])
    df["object"] = df["tgt"].apply(lambda node_id: g.nodes()[node_id])
    df = df[["subject", "edge", "object"]].rename(columns={"edge": "predicate"})
    return df.drop_duplicates().reset_index(drop=True)

def fix_llm_json_output(text):
    """Strips non-JSON lines (e.g. commentary or code fences) from the LLM output so it can be parsed as JSON."""
    # any line not starting with these characters is an LLM comment and not a JSON text line
    starting_characters = ('"', '{', '}', '[', ']')
    lines = text.splitlines()
    filtered_lines = [line for line in lines if any(line.lstrip().startswith(char) for char in starting_characters)]
    return "\n".join(filtered_lines)

def generate_graph_from_text(text, system_prompt, model=LLM_MODEL):
    """Generates a knowledge graph from a given text using an LLM."""
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=1500, chunk_overlap=150, length_function=len
    )
    pages = splitter.split_text(text)
    
    all_triplets = []
    print(f"Processing text in {len(pages)} chunks...")
    for page in tqdm(pages, desc="Generating graph triplets"):
        if len(page.strip()) < 50:
            continue
            
        user_prompt = f"context: ```{page}``` \n\n output: "
        try:
            response_dict = ollama.generate(
                model=model,
                system=system_prompt,
                prompt=user_prompt
            )
            response_text = response_dict["response"]
            cleaned_response = fix_llm_json_output(response_text)
            triplets = json.loads(cleaned_response)
            
            if isinstance(triplets, list) and all(isinstance(i, dict) for i in triplets):
                all_triplets.extend(triplets)
        except json.JSONDecodeError:
            print(f"Warning: Failed to decode JSON from LLM response for a chunk.")
            continue
        except Exception as e:
            print(f"An unexpected error occurred during LLM call: {e}")
            continue

    if not all_triplets:
        return pd.DataFrame(columns=['node_1', 'node_2', 'edge'])
        
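    # Keep only well-formed triplets. Passing `columns` drops any extra keys the LLM added and
    # guarantees the expected columns exist even if no triplet survives the filter.
    valid_triplets = [t for t in all_triplets if all(k in t for k in ['node_1', 'node_2', 'edge'])]
    df = pd.DataFrame(valid_triplets, columns=['node_1', 'node_2', 'edge'])
    return df.drop_duplicates().reset_index(drop=True)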
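`fix_llm_json_output` keeps only lines that begin with a JSON-looking character, so a response that puts the whole list on one line after some prose (e.g. `Sure, here is the list: [...]`) is reduced to an empty string and the chunk is skipped. A possible alternative, not used in this notebook, is to slice from the first `[` to the last `]` and parse that span. A minimal sketch; `extract_json_list` is a hypothetical helper that assumes each response contains one top-level JSON list:

```python
import json


def extract_json_list(text):
    """Parse the JSON list spanning the first '[' to the last ']' of an LLM response.

    Returns only dict elements carrying node_1, node_2 and edge; returns [] when
    no parsable list is found.
    """
    start = text.find("[")
    end = text.rfind("]")
    if start == -1 or end <= start:
        return []
    try:
        parsed = json.loads(text[start:end + 1])
    except json.JSONDecodeError:
        return []
    required = ("node_1", "node_2", "edge")
    # A valid JSON text delimited by [ and ] is always a list.
    return [t for t in parsed if isinstance(t, dict) and all(k in t for k in required)]


response = (
    'Sure, here is the list: [{"node_1": "Sega", "node_2": "Media.Vision", '
    '"edge": "co-developed with"}] Let me know if you need more.'
)
print(extract_json_list(response))
```

The slice still fails when trailing commentary itself contains a `]`; in that case the function returns an empty list instead of raising.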

## 5. Experiment Execution

### Step 1: Load the Dataset

In [7]:
parsed_pairs = load_wikigraphs_data(WIKIGRAPHS_DATA_DIR)

Loading WikiGraphs dataset...
Loaded 23431 pairs from the dataset.


### Step 2: Select Sample Text and Ground Truth

In [8]:
sample_pair = parsed_pairs[SAMPLE_INDEX]
sample_text = sample_pair.text

print(f"--- Analyzing Sample: '{sample_pair.title}' ---")
print(f"Text length: {len(sample_text)} characters")
print("\n--- First 500 characters of text ---")
print(sample_text[:500])

--- Analyzing Sample: 'Valkyria_Chronicles_III' ---
Text length: 20946 characters

--- First 500 characters of text ---

 = Valkyria Chronicles III = 
 
 Senjō no Valkyria 3 : Unrecorded Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable . Released in January 2011 in Japan , it is the third game in the Valkyria series . Employing the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs


In [9]:
ground_truth_df = get_ground_truth_graph(sample_pair)
print("\n--- Ground Truth Graph (First 5 Rows) ---")
display(ground_truth_df.head())


--- Ground Truth Graph (First 5 Rows) ---


|   | subject | predicate | object |
|---|---|---|---|
| 0 | ns/m.0f9q9z | key/wikipedia.en | "Sega_AM1" |
| 1 | ns/m.0f9q9z | ns/type.object.name | "Sega Wow" |
| 2 | ns/m.0f9q9z | ns/organization.organization.date_founded | "2000" |
| 3 | ns/m.0f9q9z | ns/cvg.cvg_developer.games_developed | ns/m.0ddd390 |
| 4 | ns/m.0f9q9z | ns/common.topic.description | "Sega Wow was a division of Japanese video game developer Sega." |
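The ground-truth nodes above are Freebase machine IDs (MIDs such as `ns/m.0f9q9z`) or quoted literals rather than readable names. Judging from the `ns/type.object.name` row, a MID's name can be recovered from the graph itself. A minimal sketch, assuming every named node has such a row and that literals keep their surrounding double quotes, as the display above suggests; `build_mid_to_name` and `resolve_node` are hypothetical helpers:

```python
import pandas as pd


def build_mid_to_name(truth_df):
    """Map Freebase machine IDs (MIDs) to readable names.

    Assumption: rows whose predicate is 'ns/type.object.name' carry the
    entity's name as a double-quoted literal in the 'object' column.
    """
    name_rows = truth_df[truth_df["predicate"] == "ns/type.object.name"]
    # Remove the surrounding quotes, e.g. '"Sega Wow"' -> 'Sega Wow'.
    names = name_rows["object"].str.strip('"')
    return dict(zip(name_rows["subject"], names))


def resolve_node(value, mid_to_name):
    """Return the name for a known MID, otherwise the value without surrounding quotes."""
    if value in mid_to_name:
        return mid_to_name[value]
    return value.strip('"')


# Example rows copied from the ground-truth table above.
example_truth_df = pd.DataFrame({
    "subject": ["ns/m.0f9q9z"] * 3,
    "predicate": [
        "key/wikipedia.en",
        "ns/type.object.name",
        "ns/organization.organization.date_founded",
    ],
    "object": ['"Sega_AM1"', '"Sega Wow"', '"2000"'],
})

mid_to_name = build_mid_to_name(example_truth_df)
print(mid_to_name)                                # {'ns/m.0f9q9z': 'Sega Wow'}
print(resolve_node("ns/m.0f9q9z", mid_to_name))   # Sega Wow
print(resolve_node("ns/m.0ddd390", mid_to_name))  # ns/m.0ddd390 (no name row in this sample)
```

On the real data, `mapping = build_mid_to_name(ground_truth_df)` followed by `ground_truth_df["subject"].map(lambda v: resolve_node(v, mapping))` would give readable subjects; MIDs without a name row are left unchanged.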


### Step 3: Generate Graph with Default Prompt

In [10]:
print("--- Generating Graph with Default System Prompt ---")
graph_default_df = generate_graph_from_text(sample_text, DEFAULT_SYS_PROMPT)

--- Generating Graph with Default System Prompt ---
Processing text in 21 chunks...




In [17]:
print("\n--- Generated Graph with Default Prompt (First 50 Rows) ---")
display(graph_default_df.head(50))


--- Generated Graph with Default Prompt (First 50 Rows) ---


|   | node_1 | node_2 | edge |
|---|---|---|---|
| 0 | Valkyria Chronicles III | Senjō no Valkyria 3 | same game |
| 1 | Sega | Media.Vision | co-developers of the game |
| 2 | PlayStation Portable | Valkyria Chronicles III | game platform |
| 3 | January 2011 | Japan | release date and location |
| 4 | Valkyria Chronicles III | third game in the Valkyria series | position in the series |
| 5 | Nameless | penal military unit | unit type |
| 6 | Gallia | Second Europan War | nation involved in the war |
| 7 | Nameless | secret black operations | type of operation |
| 8 | Calamaty Raven | Imperial unit | unit type |
| 9 | Raita Honjou | Character designer | character's role |


### Step 4: Generate Graph with User-Supplied Prompt

In [None]:
print("--- Generating Graph with User-Supplied System Prompt ---")
graph_user_df = generate_graph_from_text(sample_text, USER_SUPPLIED_SYS_PROMPT)

In [18]:
print("\n--- Generated Graph with User Prompt (First 50 Rows) ---")
display(graph_user_df.head(50))


--- Generated Graph with User Prompt (First 50 Rows) ---


|   | node_1 | node_2 | edge |
|---|---|---|---|
| 0 | Valkyria Chronicles III | Japanese title: Senjō no Valkyria 3 | is the Japanese name for |
| 1 | Valkyria Chronicles III | PS-P game | is developed for |
| 2 | Valkyria Chronicles III | Second Europan War | takes place during |
| 3 | Nameless | penal military unit | is a type of |
| 4 | Nameless | Gallia nation | serves in the military for |
| 5 | Calamaty Raven | Imperial unit | is an enemy of |
| 6 | Raita Honjou | Character designer | has a role as |
| 7 | Hitoshi Sakimoto | Composer | has a role as |
| 8 | May 'n | Singer of the game's opening theme | performed for |
| 9 | Valkyria Chronicles | Japan | released in Japan |


## 6. Comparison and Performance Evaluation

Now we define the functions to compare the generated graphs and measure their performance against the ground truth.
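For reference, the triplet-level scores reported by `calculate_graph_metrics` can be written as follows, where $G$ is the set of normalized generated triplets and $T$ the set of normalized ground-truth triplets; each ratio is taken as 0 when its denominator is 0:

$$
P = \frac{|G \cap T|}{|G|}, \qquad R = \frac{|G \cap T|}{|T|}, \qquad F_1 = \frac{2PR}{P + R}
$$

With zero true positives ($|G \cap T| = 0$), all three scores are 0 regardless of $|G|$ and $|T|$.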

In [32]:
def compare_generated_graphs(graph1_df, graph2_df, prompt1_name="Default", prompt2_name="User"):
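    """Performs a quantitative comparison between two generated graphs.

    Node strings are counted as-is (no lower-casing or stripping), so these counts can
    differ slightly from those reported by compare_coreference_resolution.
    """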
    print("\n--- Quantitative Graph Comparison ---")
    
    stats1 = {
        "nodes": len(pd.concat([graph1_df['node_1'], graph1_df['node_2']]).unique()),
        "edges": len(graph1_df)
    }
    stats2 = {
        "nodes": len(pd.concat([graph2_df['node_1'], graph2_df['node_2']]).unique()),
        "edges": len(graph2_df)
    }
    
    print(f"Graph from '{prompt1_name}' Prompt: {stats1['nodes']} unique nodes, {stats1['edges']} edges (triplets).")
    print(f"Graph from '{prompt2_name}' Prompt: {stats2['nodes']} unique nodes, {stats2['edges']} edges (triplets).")

def calculate_graph_metrics(generated_df, truth_df):
    """Calculates precision, recall, and F1 score for triplets."""
    # Normalize strings for fair comparison
    def normalize_df(df, cols):
        df_copy = df.copy()
        for col in cols:
            df_copy[col] = df_copy[col].astype(str).str.lower().str.strip()
        return df_copy

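    generated_norm = normalize_df(generated_df, ['node_1', 'node_2', 'edge'])
    truth_norm = normalize_df(truth_df, ['subject', 'predicate', 'object'])
    truth_norm = truth_norm.rename(columns={'subject': 'node_1', 'predicate': 'edge', 'object': 'node_2'})

    # Convert both dataframes to sets of (node_1, edge, node_2) tuples. Selecting the columns
    # explicitly keeps the tuple order identical on both sides, so equal triplets can match.
    triplet_cols = ['node_1', 'edge', 'node_2']
    generated_triplets = set(generated_norm[triplet_cols].itertuples(index=False, name=None))
    truth_triplets = set(truth_norm[triplet_cols].itertuples(index=False, name=None))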

    # --- Triplet-level Metrics ---
    true_positives = len(generated_triplets.intersection(truth_triplets))
    
    precision = true_positives / len(generated_triplets) if len(generated_triplets) > 0 else 0
    recall = true_positives / len(truth_triplets) if len(truth_triplets) > 0 else 0
    f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    
    return {
        "precision": precision,
        "recall": recall,
        "f1_score": f1_score,
        "true_positives": true_positives,
        "generated_count": len(generated_triplets),
        "truth_count": len(truth_triplets)
    }

def compare_coreference_resolution(default_graph_df, user_graph_df, prompt1_name="Default Prompt", prompt2_name="User-Supplied Prompt"):
    """
    Compares two generated graphs to analyze their relative coreference resolution performance.

    This function does NOT use a ground truth graph. Instead, it directly compares the
    entity vocabularies of the two generated graphs. The primary assumption is that a
    graph with a smaller, more consolidated set of unique entities demonstrates
    better coreference resolution.

    Args:
        default_graph_df (pd.DataFrame): The graph generated by the default prompt.
        user_graph_df (pd.DataFrame): The graph generated by the user-supplied prompt.
        prompt1_name (str): The name for the default prompt's graph.
        prompt2_name (str): The name for the user-supplied prompt's graph.
    """
    print("\n" + "="*25 + " Direct Coreference Resolution Comparison " + "="*25)
    print("This analysis compares the two generated graphs against each other.")
    print("A lower 'Total Unique Entities' count suggests better entity consolidation.\n")

    # --- Step 1: Extract and normalize unique entities from each graph ---
    
    # Extract entities from the Default Prompt's graph
    default_entities = set(pd.concat([default_graph_df['node_1'], default_graph_df['node_2']])
                           .astype(str).str.lower().str.strip().unique())
    default_entity_count = len(default_entities)

    # Extract entities from the User-Supplied Prompt's graph
    user_entities = set(pd.concat([user_graph_df['node_1'], user_graph_df['node_2']])
                        .astype(str).str.lower().str.strip().unique())
    user_entity_count = len(user_entities)
    
    # --- Step 2: Create a comparative summary ---
    
    summary_data = {
        "Metric": ["Total Unique Entities Generated", "Entities Common to Both Graphs"],
        prompt1_name: [default_entity_count, len(default_entities.intersection(user_entities))],
        prompt2_name: [user_entity_count, len(user_entities.intersection(default_entities))]
    }
    summary_df = pd.DataFrame(summary_data).set_index("Metric")
    display(summary_df)

    # --- Step 3: Provide a clear, automated interpretation ---
    
    print("\n--- Interpretation of Coreference Performance ---")
    
    if user_entity_count < default_entity_count:
        improvement = (default_entity_count - user_entity_count) / default_entity_count * 100
        print(f"Conclusion: The '{prompt2_name}' prompt shows superior coreference resolution.")
        print(f"It generated {default_entity_count - user_entity_count} fewer unique entities ({improvement:.2f}% reduction), indicating better consolidation.")
    elif default_entity_count < user_entity_count:
        degradation = (user_entity_count - default_entity_count) / user_entity_count * 100
        print(f"Conclusion: The '{prompt1_name}' prompt shows superior coreference resolution.")
        print(f"It generated {user_entity_count - default_entity_count} fewer unique entities ({degradation:.2f}% reduction), indicating better consolidation.")
    else:
        print("Conclusion: Both prompts resulted in the exact same number of unique entities.")
        print("Their coreference resolution performance appears to be identical based on this metric.")
    
    print("=" * 82)

def run_full_evaluation(graph_df, truth_df, prompt_name):
    """Runs the full suite of evaluation metrics for a given graph."""
    print(f"\n{'='*20} EVALUATION FOR '{prompt_name}' PROMPT {'='*20}")
    
    # Evaluate triplet extraction performance
    metrics = calculate_graph_metrics(graph_df, truth_df)
    
    print("\n--- Relation Extraction Performance (Triplets) ---")
    print(f"Correct Triplets (True Positives): {metrics['true_positives']}")
    print(f"Total Generated Triplets: {metrics['generated_count']}")
    print(f"Total Ground Truth Triplets: {metrics['truth_count']}")
    print("--------------------------------------------------")
    print(f"Precision: {metrics['precision']:.4f}")
    print(f"Recall:    {metrics['recall']:.4f}")
    print(f"F1-Score:  {metrics['f1_score']:.4f}")
    print("="*65 + "\n")
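`compare_coreference_resolution` treats a smaller entity set as better consolidation. One confound is that a prompt which extracts fewer triplets will usually also yield fewer unique entities; in the run below, for instance, the raw counts give 249/197 ≈ 1.26 unique nodes per triplet for the default prompt and 216/172 ≈ 1.26 for the user-supplied one. The sketch below adds two complementary checks: the entity count normalized by the number of triplets, and a list of pairs where one entity name appears as whole words inside another, a common sign of an unresolved mention. Both helpers (`entity_stats`, `unresolved_variants`) are hypothetical and not part of the original notebook, and the containment test is only a heuristic: it will also flag legitimate pairs such as a series name inside a game title.

```python
from itertools import permutations


def entity_stats(triplets):
    """Return (normalized entity set, unique triplet count, entities per triplet).

    `triplets` is an iterable of (node_1, edge, node_2) string tuples. Names are
    lower-cased and stripped, mirroring compare_coreference_resolution.
    """
    unique_triplets = set(triplets)
    entities = {name.lower().strip() for n1, _, n2 in unique_triplets for name in (n1, n2)}
    n_triplets = len(unique_triplets)
    per_triplet = len(entities) / n_triplets if n_triplets else 0.0
    return entities, n_triplets, per_triplet


def unresolved_variants(entities):
    """Heuristic: (short, long) pairs where the short name appears as whole words in the long one."""
    return sorted(
        (short_name, long_name)
        for short_name, long_name in permutations(entities, 2)
        if f" {short_name} " in f" {long_name} "
    )


# Hypothetical example: one mention left unresolved as a surname only.
triplets = [
    ("Jacqueline Fernandez", "is a", "Sri Lankan actress"),
    ("Fernandez", "was born in", "Manama, Bahrain"),
]
entities, n_triplets, per_triplet = entity_stats(triplets)
print(len(entities), n_triplets, per_triplet)  # 4 2 2.0
print(unresolved_variants(entities))           # [('fernandez', 'jacqueline fernandez')]
```

For the notebook's DataFrames, the tuples can be built with `list(graph_df[['node_1', 'edge', 'node_2']].astype(str).itertuples(index=False, name=None))`.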

### Step 5: Run Final Comparison and Evaluation

In [33]:
# Perform a direct quantitative comparison
compare_generated_graphs(graph_default_df, graph_user_df)

# Perform the direct comparison for coreference resolution
compare_coreference_resolution(graph_default_df, 
                               graph_user_df, 
                               prompt1_name="Default Prompt", 
                               prompt2_name="User-Supplied Prompt")

# Evaluate the graph from the default prompt
run_full_evaluation(graph_default_df, ground_truth_df, "Default")

# Evaluate the graph from the user-supplied prompt
run_full_evaluation(graph_user_df, ground_truth_df, "User-Supplied")


--- Quantitative Graph Comparison ---
Graph from 'Default' Prompt: 249 unique nodes, 197 edges (triplets).
Graph from 'User' Prompt: 216 unique nodes, 172 edges (triplets).

This analysis compares the two generated graphs against each other.
A lower 'Total Unique Entities' count suggests better entity consolidation.



| Metric | Default Prompt | User-Supplied Prompt |
|---|---|---|
| Total Unique Entities Generated | 248 | 215 |
| Entities Common to Both Graphs | 154 | 154 |



--- Interpretation of Coreference Performance ---
Conclusion: The 'User-Supplied Prompt' prompt shows superior coreference resolution.
It generated 33 fewer unique entities (13.31% reduction), indicating better consolidation.


--- Relation Extraction Performance (Triplets) ---
Correct Triplets (True Positives): 0
Total Generated Triplets: 197
Total Ground Truth Triplets: 55
--------------------------------------------------
Precision: 0.0000
Recall:    0.0000
F1-Score:  0.0000



--- Relation Extraction Performance (Triplets) ---
Correct Triplets (True Positives): 0
Total Generated Triplets: 172
Total Ground Truth Triplets: 55
--------------------------------------------------
Precision: 0.0000
Recall:    0.0000
F1-Score:  0.0000
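Both prompts score 0.0 on every triplet metric. Judging from the ground-truth rows displayed in Step 2, this reflects the representation as much as extraction quality: ground-truth subjects are Freebase machine IDs (e.g. `ns/m.0f9q9z`), predicates are Freebase relation paths (e.g. `ns/organization.organization.date_founded`), and literal objects keep their surrounding quotes, whereas the LLM emits free-text names and sentence-like edges, so exact whole-triplet matches are very unlikely. A softer check is an entity-level comparison once both sides are reduced to plain names. A minimal sketch; `entity_prf` is a hypothetical helper and the name lists below are illustrative, not taken from the actual ground truth:

```python
def entity_prf(generated_nodes, truth_names):
    """Entity-level precision, recall and F1 over normalized name sets.

    Names are lower-cased and stripped, then compared by exact equality;
    no alias or fuzzy matching is attempted.
    """
    gen = {str(n).lower().strip() for n in generated_nodes}
    truth = {str(n).lower().strip() for n in truth_names}
    tp = len(gen & truth)
    precision = tp / len(gen) if gen else 0.0
    recall = tp / len(truth) if truth else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0
    return {"precision": precision, "recall": recall, "f1_score": f1, "true_positives": tp}


# Illustrative (hypothetical) inputs.
generated = ["Valkyria Chronicles III", "Sega", "Media.Vision", "PlayStation Portable"]
truth = ["Sega Wow", "Valkyria Chronicles III", "Sega"]
scores = entity_prf(generated, truth)
print(scores)
```

In the notebook, `generated_nodes` could be `pd.concat([graph_df['node_1'], graph_df['node_2']])`, and the ground-truth names could come from the rows whose predicate is `ns/type.object.name`, with the surrounding quotes stripped. Exact equality will still miss paraphrases such as "PlayStation Portable" vs. "PS-P game"; similarity scores from `difflib.SequenceMatcher` are one possible extension.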

