# Knowledge Graph Generation with LLMs
![title](neo4jdogs.png)

## Overview
This notebook demonstrates how to create a knowledge graph from PDF documents using Large Language Models (specifically Azure OpenAI). The process includes:

1. Loading and preprocessing PDF documents
2. Using LLMs to extract entities and relationships
3. Constructing and storing a knowledge graph

The resulting graph is serialized for use in companion notebooks:
- [Knowledge graph with Neo4J (Cypher)](./graph-neo4j.ipynb)
- [Knowledge graph with Azure CosmosDB (Gremlin)](./graph-cosmosdb.ipynb)

# Load environment variables and connect to Azure OpenAI
 

In [None]:
import os
import logging
from dotenv import load_dotenv
from langchain_openai import AzureChatOpenAI
from logging import StreamHandler

import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger()

# Load environment variables (e.g. from a local .env file) via python-dotenv
load_dotenv()

# Validate required environment variables up front, so a misconfiguration fails
# fast with a clear message instead of surfacing later as an opaque API error.
required_vars = ["AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOYMENT_NAME", "AZURE_OPENAI_API_KEY"]
missing_vars = [var for var in required_vars if not os.getenv(var)]

if missing_vars:
    error_msg = f"Error: Missing required environment variables: {', '.join(missing_vars)}"
    logger.error(error_msg)
    raise EnvironmentError(error_msg)

# Configure LLM settings
# NOTE(review): reasoning_effort is only meaningful for reasoning-model
# deployments — confirm the configured deployment supports it.
REASONING_EFFORT = "medium"  # Options: "low", "medium", "high"
# The underlying OpenAI/httpx client interprets `timeout` in SECONDS.
# (Previously 3 * 60 * 1000 was passed, i.e. ~50 hours, not 3 minutes.)
MAX_TIMEOUT = 3 * 60  # 3 minutes in seconds

# Initialize LLM client
try:
    llm = AzureChatOpenAI(
        timeout=MAX_TIMEOUT,
        api_version="2025-02-01-preview",
        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
        azure_deployment=os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME"),
        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
        verbose=True,
        reasoning_effort=REASONING_EFFORT,
    )
    logger.info(f"Successfully initialized Azure OpenAI client with reasoning effort: {REASONING_EFFORT}")
except Exception as e:
    # logger.exception records the traceback along with the message.
    logger.exception(f"Failed to initialize Azure OpenAI client: {str(e)}")
    raise

## Download test data

In [None]:
import os
from typing import List, Dict
from tqdm import tqdm
import urllib.request

def download_documents(urls: List[str], local_folder: str = "./data/") -> List[str]:
    """
    Download documents from URLs into a local folder, skipping files already present.

    Args:
        urls: List of document URLs to download.
        local_folder: Directory to save downloaded files (created if missing).

    Returns:
        File names (relative to ``local_folder``) of the documents that are
        available locally, either downloaded by this call or already present.
        URLs that fail to download are logged and omitted from the result.
    """
    from urllib.parse import urlparse

    # Ensure the download directory exists
    os.makedirs(local_folder, exist_ok=True)

    doc_names = []

    for url in tqdm(urls, desc="Downloading documents"):
        # Take the file name from the URL *path*, so query strings or fragments
        # (e.g. "?download=1") do not end up in the local file name.
        file_name = os.path.basename(urlparse(url).path)
        if not file_name:
            logger.error(f"Cannot derive a file name from URL: {url}")
            continue
        file_path = os.path.join(local_folder, file_name)

        # Skip if file already exists
        if os.path.isfile(file_path):
            logger.info(f"File already exists: {file_name}")
            doc_names.append(file_name)
            continue

        # Download to a temporary name and rename only on success. An interrupted
        # download then never leaves a truncated file that later runs would
        # mistake for a complete one and skip.
        tmp_path = file_path + ".part"
        try:
            logger.info(f"Downloading {url}")
            urllib.request.urlretrieve(url, tmp_path)
            os.replace(tmp_path, file_path)
        except Exception as e:
            logger.error(f"Failed to download {url}: {str(e)}")
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # best-effort cleanup; the partial file may not exist
            continue

        logger.info(f"Downloaded {file_name}")
        doc_names.append(file_name)

    return doc_names

# Define document URLs.
# Keep a trailing comma after EVERY entry: without it, a string added on the next
# line would be silently concatenated with this one (implicit string-literal
# concatenation) instead of becoming a separate list element.
document_urls = [
    "https://www.marinhumane.org/wp-content/uploads/2017/06/Dog-Breed-Characteristics-Behavior.pdf",
    # Add more document URLs here as needed
]

# Download documents
local_folder = "./data/"
doc_names = download_documents(document_urls, local_folder)
logger.info(f"Downloaded {len(doc_names)} documents")

## Loading PDF Documents with LangChain
Use LangChain's built-in document loaders to process the PDFs.

In [None]:
import time
from langchain_community.document_loaders import PyMuPDFLoader, PyPDFLoader

# Create a dictionary to store Document objects for each PDF
document_objects = {}
all_documents = []

# NOTE: the loader is configured in 'single' mode, so each PDF is loaded as one
# Document holding the full text. This works best for reasonably small files
# (up to roughly 20 pages); larger documents are better split first.

for doc_name in doc_names:
    print(f"Processing {doc_name}")
    # perf_counter is monotonic and meant for measuring durations.
    start_time = time.perf_counter()

    # os.path.join works whether or not local_folder ends with a separator.
    pdf_path = os.path.join(local_folder, doc_name)

    # Use PyMuPDFLoader for better PDF text extraction.
    # If PyMuPDF (fitz) is not installed, PyPDFLoader is an alternative:
    #     loader = PyPDFLoader(pdf_path)
    try:
        loader = PyMuPDFLoader(pdf_path, mode="single")
        # Load the documents - this returns a list of Document objects
        documents = loader.load()
    except Exception as e:
        print(f"Error loading {doc_name}: {e}")
    else:
        document_objects[doc_name] = documents
        all_documents.extend(documents)
        # Display information about the loaded documents
        print(f"Loaded {len(documents)} document from {doc_name}")

    # Calculate processing time
    elapsed_time = time.perf_counter() - start_time
    print(f"Processing completed in {elapsed_time:.6f} seconds\n")

print(f"Total Document objects created: {len(all_documents)}")

# Display content of the first document to verify extraction
if all_documents:
    print("\nFirst 200 characters of content:")
    print(all_documents[0].page_content[:200] + "...")

## Knowledge Graph Extraction

We use `LLMGraphTransformer` from `langchain_experimental` to extract entities and relationships from the documents.
We restrict extraction to a set of allowed node and relationship types and then configure the extraction process.

In [None]:
from langchain_experimental.graph_transformers import LLMGraphTransformer
import time
from tqdm.notebook import tqdm

# Our custom 🐶 schema: the node labels and relationship types the transformer
# is allowed to emit.
allowed_node_types = ["Breed", "BreedingGroup", "Characteristics", "Trait", "Origin"]
allowed_relationship_types = ["HAS_CHARACTERISTIC", "BELONGS_TO", "ORIGINATED_FROM"]

# Create the transformer with this configuration
llm_transformer = LLMGraphTransformer(
    llm=llm,
    allowed_nodes=allowed_node_types,
    allowed_relationships=allowed_relationship_types,
)


In [None]:
def process_documents_to_graph(documents, batch_size=5, transformer=None):
    """
    Convert documents into graph documents in batches, with progress tracking.

    Processing is best-effort: a batch that raises is logged with its traceback
    and skipped, so one failed LLM call does not abort the whole run. The number
    of failed batches is reported once processing finishes.

    Args:
        documents: List of Document objects to process.
        batch_size: Number of documents passed to the transformer per call;
            must be at least 1.
        transformer: Object providing ``convert_to_graph_documents(batch)``.
            Defaults to the module-level ``llm_transformer``.

    Returns:
        List of graph documents produced by all successful batches.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if transformer is None:
        transformer = llm_transformer

    graph_documents = []
    failed_batches = 0
    start_time = time.perf_counter()

    for i in tqdm(range(0, len(documents), batch_size), desc="Processing document batches"):
        batch = documents[i:i + batch_size]
        batch_number = i // batch_size + 1
        try:
            batch_results = transformer.convert_to_graph_documents(batch)
        except Exception:
            # Continue with the next batch instead of failing completely,
            # but keep the traceback for debugging.
            failed_batches += 1
            logger.exception(f"Error processing batch {batch_number}")
            continue
        graph_documents.extend(batch_results)
        logger.info(f"Processed batch {batch_number} with {len(batch)} documents")

    elapsed_time = time.perf_counter() - start_time
    logger.info(f"Processing completed in {elapsed_time:.2f} seconds")
    if failed_batches:
        logger.warning(f"{failed_batches} batch(es) failed; their documents were not converted")

    return graph_documents

# Turn every loaded document into its graph representation
graph_documents = process_documents_to_graph(all_documents)

# Report what the extraction produced
if graph_documents:
    total_nodes = sum(len(gdoc.nodes) for gdoc in graph_documents)
    total_rels = sum(len(gdoc.relationships) for gdoc in graph_documents)
    print(f"Extracted {total_nodes} nodes and {total_rels} relationships from {len(graph_documents)} documents")

    # Show up to three nodes and relationships from the first graph document
    first_graph_doc = graph_documents[0]
    print("\nSample from first document:")
    if first_graph_doc.nodes:
        print(f"Nodes: {first_graph_doc.nodes[:3]}")
    else:
        print("No nodes found")
    if first_graph_doc.relationships:
        print(f"Relationships: {first_graph_doc.relationships[:3]}")
    else:
        print("No relationships found")

## Store the graph documents
Save the extracted graph data to a pickle file for later use in the companion notebooks.

In [None]:
# Persist the graph documents so the extraction doesn't have to be redone every time 🥒
import pickle

graph_docs_path = './data/graph_docs.pkl'

# NOTE: unpickling can execute arbitrary code — only load this file from a
# trusted source (e.g. the file written here).
with open(graph_docs_path, 'wb') as pickle_file:
    pickle.dump(graph_documents, pickle_file)

print(f"Saved {len(graph_documents)} graph documents to '{graph_docs_path}'")

## Visualize and Analyze the Knowledge Graph

Let's create a simple visualization of our knowledge graph using NetworkX and matplotlib.
(It looks nicer in Neo4j ...)

In [None]:
import networkx as nx
import matplotlib.pyplot as plt
from collections import Counter

def build_networkx_graph(graph_docs):
    """
    Convert graph documents into a NetworkX MultiDiGraph for visualization and analysis.

    Nodes are keyed by ``str(node.id)``. If the same id appears more than once,
    the attributes of its first occurrence are kept. Each relationship becomes one
    directed edge from its source node to its target node; parallel edges are
    kept because the graph is a MultiDiGraph.

    Args:
        graph_docs: List of graph documents exposing ``.nodes`` (objects with
            ``id``, ``type``, ``properties``) and ``.relationships`` (objects with
            ``source``, ``target``, ``type``, ``properties``).

    Returns:
        NetworkX MultiDiGraph. Node attributes are ``label``, ``type`` and
        ``properties``; edge attributes are ``type`` and ``properties``.
    """
    G = nx.MultiDiGraph()

    def _add_node(node):
        """Add a Node-like object to G (first occurrence wins) and return its key."""
        node_id = str(node.id)
        if node_id not in G:
            G.add_node(
                node_id,
                label=node.id,
                type=node.type,
                properties=getattr(node, "properties", {}),
            )
        return node_id

    def _endpoint_key(endpoint):
        """Return the graph key for a relationship endpoint (a Node or a bare id)."""
        # Endpoints are Node objects in LangChain's graph documents. Key them by
        # their id so edges attach to the nodes added below. Formatting the whole
        # object (f"{rel.source}") yields its repr instead, which created a
        # duplicate, attribute-less node for every edge endpoint.
        if hasattr(endpoint, "id") and hasattr(endpoint, "type"):
            return _add_node(endpoint)
        return str(endpoint)

    # Add all declared nodes first so their attributes take precedence
    for doc in graph_docs:
        for node in doc.nodes:
            _add_node(node)

    # Then add all relationships (endpoints missing from doc.nodes are added
    # with the type carried by the relationship's Node object)
    for doc in graph_docs:
        for rel in doc.relationships:
            G.add_edge(
                _endpoint_key(rel.source),
                _endpoint_key(rel.target),
                type=rel.type,
                properties=rel.properties,
            )

    return G

# Build, summarize and draw the graph. This is best-effort: any failure is logged
# with its traceback rather than aborting the notebook.
try:
    G = build_networkx_graph(graph_documents)

    # Print graph statistics
    print("Graph Statistics:")
    print(f"- Nodes: {G.number_of_nodes()}")
    print(f"- Edges: {G.number_of_edges()}")

    # Count node types ('Unknown' for nodes that carry no 'type' attribute).
    # node_types is in G.nodes order, which the drawing below relies on.
    node_types = [data.get('type', 'Unknown') for _, data in G.nodes(data=True)]
    type_counts = Counter(node_types)
    print("\nNode type distribution:")
    for node_type, count in type_counts.items():
        print(f"- {node_type}: {count} nodes")

    # Count relationship types ('Unknown' for edges without a 'type' attribute)
    rel_types = [data.get('type', 'Unknown') for _, _, data in G.edges(data=True)]
    rel_counts = Counter(rel_types)
    print("\nRelationship type distribution:")
    for rel_type, count in rel_counts.items():
        print(f"- {rel_type}: {count} relationships")

    # Visualize the graph
    plt.figure(figsize=(12, 10))

    # Use different colors for different node types; unmapped types fall back to gray
    color_map = {
        'Breed': 'skyblue',
        'BreedingGroup': 'lightgreen',
        'Characteristics': 'salmon',
        'Trait': 'orange',
        'Origin': 'purple',
        'Unknown': 'gray'
    }
    default_color = 'gray'
    node_colors = [color_map.get(node_type, default_color) for node_type in node_types]

    # Position nodes using spring layout
    pos = nx.spring_layout(G, seed=42)  # For reproducibility

    # Draw nodes and edges
    nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=700, alpha=0.8)
    nx.draw_networkx_edges(G, pos, width=1.0, alpha=0.5, arrowsize=20)

    # Add labels with smaller font size
    labels = {node: data.get('label', node) for node, data in G.nodes(data=True)}
    nx.draw_networkx_labels(G, pos, labels=labels, font_size=8)

    # One legend entry per node type actually present, colored exactly like the
    # nodes, so gray-fallback types are listed too
    legend_patches = [
        plt.Line2D([0], [0], marker='o', color='w',
                   label=node_type,
                   markerfacecolor=color_map.get(node_type, default_color), markersize=10)
        for node_type in type_counts
    ]
    if legend_patches:
        plt.legend(handles=legend_patches, loc='upper right')

    plt.title('Knowledge Graph Visualization')
    plt.axis('off')
    plt.tight_layout()
    plt.savefig('./data/knowledge_graph.png', dpi=300, bbox_inches='tight')
    plt.show()

except Exception as e:
    # logger.exception includes the full traceback for debugging
    logger.exception(f"Error visualizing graph: {str(e)}")
