# 0. Libraries

In [1]:
# Utilities
import asyncio
from dotenv import load_dotenv
import os
import json
from google import genai
import polars as pl
from kg_builder import CustomKGPipeline, build_kg_from_df, GeminiLLM

# Neo4j and Neo4j GraphRAG imports
import neo4j
from neo4j_graphrag.embeddings import SentenceTransformerEmbeddings

Let's first check the available Gemini models.

In [3]:
# Read settings from the local .env file; override=True lets its values
# replace variables that are already set in the environment.
load_dotenv('.env', override=True)

gemini_api_key = os.getenv('GEMINI_API_KEY')

# Fail fast when the key is missing; otherwise build the Gemini client
if not gemini_api_key:
    raise ValueError("GEMINI_API_KEY environment variable is not set.")
client = genai.Client(api_key=gemini_api_key)

# Print each model returned by the API
for model in client.models.list():
    print(model)

name='models/embedding-gecko-001' display_name='Embedding Gecko' description='Obtain a distributed representation of a text.' version='001' endpoints=None labels=None tuned_model_info=TunedModelInfo(base_model=None, create_time=None, update_time=None) input_token_limit=1024 output_token_limit=1 supported_actions=['embedText', 'countTextTokens'] default_checkpoint_id=None checkpoints=None
name='models/gemini-1.0-pro-vision-latest' display_name='Gemini 1.0 Pro Vision' description='The original Gemini 1.0 Pro Vision model version which was optimized for image understanding. Gemini 1.0 Pro Vision was deprecated on July 12, 2024. Move to a newer Gemini version.' version='001' endpoints=None labels=None tuned_model_info=TunedModelInfo(base_model=None, create_time=None, update_time=None) input_token_limit=12288 output_token_limit=4096 supported_actions=['generateContent', 'countTokens'] default_checkpoint_id=None checkpoints=None
name='models/gemini-pro-vision' display_name='Gemini 1.0 Pro Vi

We also need to make sure that the spaCy model used for text embedding during the entity-resolution step is installed.

In [2]:
import spacy
import importlib.util
import subprocess
import sys

def ensure_spacy_model(model_name):
    """Make sure a spaCy model package is available, downloading it if needed.

    Looks for an importable package called ``model_name``. When none is found,
    runs ``python -m spacy download <model_name>`` with the interpreter that is
    executing this code, then refreshes the import caches so the freshly
    installed package can be discovered without restarting the process.

    Args:
        model_name: Name of the spaCy model package, e.g. ``"en_core_web_lg"``.

    Raises:
        subprocess.CalledProcessError: If the download command exits with a
            non-zero status.
    """
    if importlib.util.find_spec(model_name) is None:
        print(f"Model '{model_name}' not found. Installing...")
        subprocess.check_call([sys.executable, "-m", "spacy", "download", model_name])
        # The package was installed while this interpreter was already running;
        # path finders cache directory listings, so clear them or the new
        # package may stay invisible to later imports in this session.
        importlib.invalidate_caches()
    else:
        print(f"Model '{model_name}' is already installed.")

# The KG pipeline's entity-resolution step depends on this spaCy model
ensure_spacy_model(model_name="en_core_web_lg")

Model 'en_core_web_lg' is already installed.


# 1. Loading the data

The data is loaded here for reference only; it is loaded again inside the pipeline below.

In [8]:
# Locate the parquet file relative to the directory the notebook runs from
notebook_dir = os.getcwd()
df_path = os.path.join(notebook_dir, 'FILTERED_DATAFRAME.parquet')

# Read it into a polars DataFrame and preview the first rows
df = pl.read_parquet(df_path)
df.head()

state,date,month_year,year,event_code,quad_class,goldstein_scale,avg_tone,actor1_statecode,actor2_statecode,url,title,full_text
str,i64,i64,i64,i64,i64,f64,f64,str,str,str,str,str
"""USMO""",20210514,202105,2021,16,1,-2.0,-8.934073,"""USMO""","""USMO""","""https://www.natlawreview.com/a…","""State of the Law for Business …","""It’s been a year since COVID-1…"
"""USMO""",20210514,202105,2021,141,3,-6.5,-0.808625,"""USMO""","""USMO""","""https://www.kcur.org/health/20…","""Medicaid Expansion Supporters …","""A day after Missouri Gov. Mike…"
"""USMO""",20210529,202105,2021,13,1,0.4,-6.008584,"""USMO""","""USMO""","""https://www.dailystar.co.uk/ne…","""Elderly woman sucker-punched t…","""Elderly woman sucker-punched t…"
"""USAR""",20200207,202002,2020,16,1,-2.0,-8.0,"""USAR""",,"""https://www.houstonchronicle.c…",,
"""USNH""",20201206,202012,2020,70,2,7.0,0.088106,"""USNH""","""USNH""","""https://www.fosters.com/story/…","""Historically Speaking: Adventu…","""Historically Speaking: Adventu…"


There are three fields that can be useful as document metadata:
1. The `title` of the article.
2. The `date` when the article was published.
3. Its `url`.

Let's convert the `date` column from an integer in `YYYYMMDD` form to a proper date type.

In [9]:
# Parse the integer YYYYMMDD `date` column into a polars Date.
# The dtype guard keeps this cell idempotent: once the column is a Date,
# casting it to String yields 'YYYY-MM-DD', which no longer matches the
# '%Y%m%d' format, so re-running the cell unguarded would raise.
if df.schema['date'] != pl.Date:
    df = df.with_columns(
        pl.col('date').cast(pl.String).str.strptime(pl.Date, format='%Y%m%d')
    )

df.head()

state,date,month_year,year,event_code,quad_class,goldstein_scale,avg_tone,actor1_statecode,actor2_statecode,url,title,full_text
str,date,i64,i64,i64,i64,f64,f64,str,str,str,str,str
"""USMO""",2021-05-14,202105,2021,16,1,-2.0,-8.934073,"""USMO""","""USMO""","""https://www.natlawreview.com/a…","""State of the Law for Business …","""It’s been a year since COVID-1…"
"""USMO""",2021-05-14,202105,2021,141,3,-6.5,-0.808625,"""USMO""","""USMO""","""https://www.kcur.org/health/20…","""Medicaid Expansion Supporters …","""A day after Missouri Gov. Mike…"
"""USMO""",2021-05-29,202105,2021,13,1,0.4,-6.008584,"""USMO""","""USMO""","""https://www.dailystar.co.uk/ne…","""Elderly woman sucker-punched t…","""Elderly woman sucker-punched t…"
"""USAR""",2020-02-07,202002,2020,16,1,-2.0,-8.0,"""USAR""",,"""https://www.houstonchronicle.c…",,
"""USNH""",2020-12-06,202012,2020,70,2,7.0,0.088106,"""USNH""","""USNH""","""https://www.fosters.com/story/…","""Historically Speaking: Adventu…","""Historically Speaking: Adventu…"


In [5]:
# Work with only the first 100 rows while testing
df = df.head(n=100)

# 2. Running the pipeline

In [11]:
# Load the pipeline configuration from the notebook's working directory
script_dir = os.getcwd()
config_path = os.path.join(script_dir, 'config.json')
# Read as UTF-8 explicitly: without `encoding`, open() uses the platform's
# default codec, which can mis-decode non-ASCII text in the JSON file.
with open(config_path, 'r', encoding='utf-8') as config_file:
    config = json.load(config_file)

# Use the custom template only when `use_default` is explicitly false.
# `== False` is kept on purpose: a missing key (None) must fall through to None,
# which presumably tells the pipeline to use its built-in template (confirm in kg_builder).
prompt_template_config = config['prompt_template_config']
if prompt_template_config.get('use_default') == False:
    prompt_template = prompt_template_config['template']
else:
    prompt_template = None
print(prompt_template)

You are a top-tier algorithm designed for extracting information in structured formats to build a knowledge graph that will be used for creating security reports for different countries.

Extract the entities (nodes) and specify their type from the following Input text.
Also extract the relationships between these nodes. The relationship direction goes from the start node to the end node.

Return result as JSON using the following format:
{{"nodes": [ {{"id": "0", "label": "the type of entity", "properties": {{"name": "name of entity" }} }}],
"relationships": [{{"type": "TYPE_OF_RELATIONSHIP", "start_node_id": "0", "end_node_id": "1", "properties": {{"details": "Description of the relationship"}} }}] }}

- Use only the information from the Input text. Do not add any additional information.
- If the input text is empty, return empty Json.
- Make sure to create as many nodes and relationships as needed to offer rich context for generating a security-related knowledge graph.
- An AI knowl

## 2.1. With a single text string

In [2]:
# Example usage with a single text string
async def process_single_text():
    """Run the knowledge-graph pipeline on one hard-coded sample text.

    Reads `.env` and `config.json` from the current working directory, builds
    the Gemini LLM wrapper and the sentence-transformer embedder from that
    configuration, opens a Neo4j driver, and sends a short sample text
    (with a title, metadata and document id) through `CustomKGPipeline`.

    Returns:
        The value returned by `kg_pipeline.run_async`.

    Raises:
        ValueError: If `GEMINI_API_KEY` or `NEO4J_URI` is not set in the
            environment.
    """
    # Load configuration and setup
    script_dir = os.getcwd()

    # Load environment variables
    dotenv_path = os.path.join(script_dir, '.env')
    load_dotenv(dotenv_path, override=True)

    # Open configuration file; UTF-8 is given explicitly so the result does
    # not depend on the platform's default text encoding
    config_path = os.path.join(script_dir, 'config.json')
    with open(config_path, 'r', encoding='utf-8') as config_file:
        config = json.load(config_file)

    # Neo4j connection
    neo4j_uri = os.getenv('NEO4J_URI')
    neo4j_username = os.getenv('NEO4J_USERNAME')
    neo4j_password = os.getenv('NEO4J_PASSWORD')
    gemini_api_key = os.getenv('GEMINI_API_KEY')

    if not gemini_api_key:
        raise ValueError("Gemini API key is not set. Please provide a valid API key.")

    # Without a URI the driver would be handed None and fail with an obscure
    # error; report the actual cause before any model is loaded.
    if not neo4j_uri:
        raise ValueError("NEO4J_URI environment variable is not set.")

    # Initialize LLM
    llm = GeminiLLM(
        model_name=config['llm_config']['model_name'],
        google_api_key=gemini_api_key,
        model_params=config['llm_config']['model_params']
    )

    # Initialize embedder
    embedder = SentenceTransformerEmbeddings(model=config['embedder_config']['model_name'])

    # Configure text splitter
    text_splitter_config = config['text_splitter_config']

    # Use the custom prompt only when `use_default` is explicitly false;
    # `== False` is deliberate so a missing key (None) falls through to None.
    prompt_template_config = config['prompt_template_config']
    if prompt_template_config.get('use_default') == False:
        prompt_template = prompt_template_config['template']
    else:
        prompt_template = None

    # Sample text to process - using a short example for testing
    sample_text = """The Secretary-General strongly condemns the attacks on 15 January against the United Nations 
    Multidimensional Integrated Stabilization Mission in Mali (MINUSMA) in Aguelhok, Kidal region. The attacks 
    resulted in the death of four peacekeepers from Chad and serious injuries to nineteen others."""

    # Document information
    document_title = "UN Secretary-General Condemns Attack on Peacekeepers in Mali"
    document_metadata = {
        "source": "https://www.un.org/press/en/example-url.html",
        "published_date": "2021-01-16"
    }
    document_id = "test_doc_001"

    # The `with` block closes the driver when processing finishes or fails
    with neo4j.GraphDatabase.driver(neo4j_uri, auth=(neo4j_username, neo4j_password)) as driver:
        # Initialize the custom KG pipeline
        kg_pipeline = CustomKGPipeline(
            llm=llm,
            driver=driver,
            embedder=embedder,
            schema_config=config['schema_config'],
            prompt_template=prompt_template,
            text_splitter_config=text_splitter_config,
            perform_entity_resolution=config.get('perform_entity_resolution', True),
            similarity_threshold=0.8,
            on_error='RAISE',
            batch_size=1000,
            max_concurrency=5,
            examples=""
        )

        print("Pipeline initialized successfully. Processing text...")

        # Process the single text with the pipeline
        result = await kg_pipeline.run_async(
            text=sample_text,
            document_base_field=document_title,
            document_metadata=document_metadata,
            document_id=document_id
        )

        print("Text processing complete.")

    return result

# Run the function
print("Starting single text processing...")
result = await process_single_text()
print(f"Pipeline completed with result: {result}")

# Optional: If you want to see details of what was created in the knowledge graph
if result and hasattr(result, 'stats'):
    print("\nKG Creation Statistics:")
    for key, value in result.stats.items():
        print(f"  {key}: {value}")

Starting single text processing...


  from .autonotebook import tqdm as notebook_tqdm


Pipeline initialized successfully. Processing text...
Text processing complete.
Pipeline completed with result: run_id='c58760b7-8218-4ca0-9e34-f0d7a4939e21' result={'resolver': {'number_of_nodes_to_resolve': 8, 'number_of_created_nodes': 0}}


## 2.2. With a data frame

In [11]:
# Example usage code
async def main():
    """Build the knowledge graph from the first rows of the filtered dataframe.

    Reads `.env` and `config.json` from the current working directory, builds
    the Gemini LLM wrapper and the sentence-transformer embedder, loads
    `FILTERED_DATAFRAME.parquet`, reformats its `date` column as
    'YYYY-MM-DD' strings, keeps the first 10 rows, and passes them to
    `build_kg_from_df` with `title` as the document base field and
    `full_text` as the text column.

    Returns:
        The value returned by `build_kg_from_df`.

    Raises:
        ValueError: If `GEMINI_API_KEY` or `NEO4J_URI` is not set in the
            environment.
    """
    # Load configuration and setup
    script_dir = os.getcwd()
    # script_dir = os.path.dirname(os.path.abspath(__file__))  # Uncomment if running as a script

    # Load environment variables from a .env file
    dotenv_path = os.path.join(script_dir, '.env')
    load_dotenv(dotenv_path, override=True)

    # Open configuration file from JSON format; UTF-8 is given explicitly so
    # the result does not depend on the platform's default text encoding
    config_path = os.path.join(script_dir, 'config.json')
    with open(config_path, 'r', encoding='utf-8') as config_file:
        config = json.load(config_file)

    # Neo4j connection
    neo4j_uri = os.getenv('NEO4J_URI')
    neo4j_username = os.getenv('NEO4J_USERNAME')
    neo4j_password = os.getenv('NEO4J_PASSWORD')
    gemini_api_key = os.getenv('GEMINI_API_KEY')

    # Fail fast on missing settings
    if not gemini_api_key:
        raise ValueError("Gemini API key is not set. Please provide a valid API key.")

    # Without a URI the driver would be handed None and fail with an obscure
    # error; report the actual cause before any model or data is loaded.
    if not neo4j_uri:
        raise ValueError("NEO4J_URI environment variable is not set.")

    # Initialize LLM
    llm = GeminiLLM(
        model_name=config['llm_config']['model_name'],
        google_api_key=gemini_api_key,
        model_params=config['llm_config']['model_params']
    )

    # Initialize embedder
    embedder = SentenceTransformerEmbeddings(model=config['embedder_config']['model_name'])

    # Configure text splitter
    text_splitter_config = config['text_splitter_config']

    # Use the custom prompt only when `use_default` is explicitly false;
    # `== False` is deliberate so a missing key (None) falls through to None.
    prompt_template_config = config['prompt_template_config']
    if prompt_template_config.get('use_default') == False:
        prompt_template = prompt_template_config['template']
    else:
        prompt_template = None

    # Load data
    df_path = os.path.join(script_dir, 'FILTERED_DATAFRAME.parquet')
    df = pl.read_parquet(df_path)

    # Reformat the 'date' column from YYYYMMDD to 'YYYY-MM-DD' strings:
    # cast to string, parse as a Date, then render in the target format
    df = df.with_columns(
        pl.col('date')
        .cast(pl.String)
        .str.strptime(pl.Date, format='%Y%m%d')
        .dt.strftime('%Y-%m-%d')
    )

    # Create subset of the dataframe for testing
    df = df.head(10)

    # The `with` block closes the driver when processing finishes or fails
    with neo4j.GraphDatabase.driver(neo4j_uri, auth=(neo4j_username, neo4j_password)) as driver:

        # Initialize the custom KG pipeline
        kg_pipeline = CustomKGPipeline(
            llm=llm,
            driver=driver,
            embedder=embedder,
            schema_config=config['schema_config'],
            prompt_template=prompt_template,
            text_splitter_config=text_splitter_config,
            perform_entity_resolution=config.get('perform_entity_resolution', True),  # Default to True if not specified
            similarity_threshold=0.8,
            on_error='RAISE',
            batch_size=1000,
            max_concurrency=5,
            examples=""
        )

        # Map document property names (in addition to the base field) to the
        # dataframe columns that supply their values
        metadata_mapping = {
            "source": "url",
            "published_date": "date"
        }

        # Process the dataframe
        results = await build_kg_from_df(
            kg_pipeline=kg_pipeline,
            df=df,
            document_base_field='title',
            text_column='full_text',
            document_metadata_mapping=metadata_mapping,
            document_id_column=None  # Use default document ID generation
        )

    return results

# In a Jupyter notebook the coroutine can be awaited directly at the top level
results = await main()
print(f"Processed {len(results)} documents")

# # When running as a plain script instead, start the event loop with asyncio.run:
# if __name__ == "__main__":
#     results = asyncio.run(main())
#     print(f"Processed {len(results)} documents")

Processing row 1 of 10
Result: run_id='8e749b6b-2127-47a1-9e9d-5e49c2da8c03' result={'resolver': {'number_of_nodes_to_resolve': 19, 'number_of_created_nodes': 3}}
Elapsed time: 15.58 seconds
Estimated time remaining: 140.18 seconds

Processing row 2 of 10
Result: run_id='7be02a04-abd5-4e44-8502-c9b1a6789925' result={'resolver': {'number_of_nodes_to_resolve': 55, 'number_of_created_nodes': 4}}
Elapsed time: 34.68 seconds
Estimated time remaining: 138.72 seconds

Processing row 3 of 10
Result: run_id='507facd2-caa5-4366-9bbb-2f76c1ea3606' result={'resolver': {'number_of_nodes_to_resolve': 59, 'number_of_created_nodes': 2}}
Elapsed time: 41.12 seconds
Estimated time remaining: 95.96 seconds

Processing row 4 of 10
Skipping row 4 due to empty text
Elapsed time: 41.12 seconds
Estimated time remaining: 61.69 seconds

Processing row 5 of 10
Result: run_id='edf86596-f0e2-441f-a5d2-476479c79d47' result={'resolver': {'number_of_nodes_to_resolve': 136, 'number_of_created_nodes': 9}}
Elapsed time:

Result (after running `MATCH p=()-->() RETURN p LIMIT 1000`):

![image.png](attachment:image.png)