# Hybrid Knowledge Graph Extraction for Antiquities Trafficking

This notebook implements a hybrid pipeline combining:
- **spaCy** for fast entity extraction
- **Gemini LLM** for coreference resolution and entity refinement
- **Knowledge Graph** construction and visualization

## Architecture
1. Stage 1: Fast NLP extraction (spaCy) → 259 entities in ~1s
2. Stage 2: LLM refinement (Gemini) → Coreference + Canonicalization
3. Stage 3: Knowledge Graph construction
4. Stage 4: Visualization and export

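The stages pass plain dictionaries to one another. The sketch below writes those shapes down as `TypedDict` classes, inferred from the functions in this notebook. The class names (`RawEntity`, `RefinedEntity`, `StageTwoResult`) are hypothetical: nothing in the pipeline defines or checks them, and the example values are hand-written.

```python
from typing import Any, Dict, List, TypedDict


class RawEntity(TypedDict):
    """One spaCy entity span (Stage 1 output)."""
    text: str
    label: str   # spaCy label, e.g. "PERSON", "ORG", "GPE"
    start: int   # character offset where the span starts
    end: int     # character offset just past the span


class RefinedEntity(TypedDict):
    """One mention mapped to a canonical entity (Stage 2 output)."""
    original_text: str
    canonical_id: str
    entity_class: str
    attributes: Dict[str, Any]
    start: int
    end: int


class StageTwoResult(TypedDict):
    """The dict every Stage 2 option returns."""
    refined_entities: List[RefinedEntity]
    entity_clusters: Dict[str, List[str]]  # canonical_id -> surface mentions
    events: List[Dict[str, Any]]


# Hand-written example instance (not real model output)
example: StageTwoResult = {
    "refined_entities": [{
        "original_text": "Giacomo Medici",
        "canonical_id": "medici",
        "entity_class": "PERSON",
        "attributes": {"full_name": "Giacomo Medici"},
        "start": 0,
        "end": 14,
    }],
    "entity_clusters": {"medici": ["Giacomo Medici", "Medici"]},
    "events": [],
}
```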
## 📦 Installation

In [None]:
# Install required packages
!pip install -q spacy google-generativeai networkx matplotlib pyvis
!python -m spacy download en_core_web_trf

## 🔑 Configuration

In [None]:
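import os
import json

import spacy
import google.generativeai as genai
import networkx as nx
import matplotlib.pyplot as plt

# Set your Gemini API key.
# Get a key from: https://makersuite.google.com/app/apikey
# Reading it from the GEMINI_API_KEY environment variable keeps it out of the saved notebook.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "YOUR_API_KEY_HERE")  # or replace the default with your key

# Configure Gemini
genai.configure(api_key=GEMINI_API_KEY)
# "gemini-2.0-flash-exp" is an experimental model; if it is unavailable to your key,
# substitute a model name that is.
model = genai.GenerativeModel("gemini-2.0-flash-exp")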

# Load spaCy model
print("Loading spaCy model...")
nlp = spacy.load("en_core_web_trf")
print("✅ Setup complete!")

## 🔍 Stage 1: Fast Entity Extraction (spaCy)

In [None]:
def extract_entities_fast(text):
    """Stage 1: Fast entity extraction with spaCy"""
    doc = nlp(text)
    
    entities = []
    for ent in doc.ents:
        entities.append({
            "text": ent.text,
            "label": ent.label_,
            "start": ent.start_char,
            "end": ent.end_char,
        })
    
    return {
        "entities": entities,
        "text": text
    }

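Before choosing a Stage 2 mode, it helps to see which labels spaCy produced: `create_kg_from_spacy_only` keeps only PERSON, ORG, GPE, and FAC spans, and the simple LLM mode keeps only PERSON, ORG, and GPE. Here is a minimal sketch over the `extract_entities_fast` output format. `summarize_labels` is a hypothetical helper, and the input is hand-written with illustrative offsets rather than real spaCy output.

```python
from collections import Counter


def summarize_labels(raw_entities, keep=("PERSON", "ORG", "GPE", "FAC")):
    """Count entity labels and report how many spans a mode would keep."""
    counts = Counter(e["label"] for e in raw_entities["entities"])
    kept = sum(n for label, n in counts.items() if label in keep)
    return counts, kept


# Hand-written input in the same shape as extract_entities_fast() output
raw = {"entities": [
    {"text": "Giacomo Medici", "label": "PERSON", "start": 0, "end": 14},
    {"text": "2005", "label": "DATE", "start": 70, "end": 74},
    {"text": "Medici", "label": "PERSON", "start": 160, "end": 166},
    {"text": "Rome", "label": "GPE", "start": 200, "end": 204},
]}

counts, kept = summarize_labels(raw)
print(counts.most_common())
print(f"{kept} of {sum(counts.values())} spans would become graph mentions")
```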
## 🤖 Stage 2: LLM Refinement Functions

### Option A: Simple (Fast, Coreference Only)

In [None]:
def llm_refine_entities_simple(raw_entities, text):
    """Simplified: Just coreference resolution"""
    
    entities = raw_entities["entities"]
    
    # Focus on the key entity types: people, organizations, and places
    key_entities = [e for e in entities if e['label'] in ['PERSON', 'ORG', 'GPE']]
    
    print(f"  Focusing on {len(key_entities)} PERSON/ORG/GPE entities...")
    
    # Group by type
    entity_list = {}
    for e in key_entities:
        label = e['label']
        if label not in entity_list:
            entity_list[label] = []
        entity_list[label].append(e['text'])
    
    # Deduplicate (sorted so the prompt, and the 50-mention cap below, are deterministic)
    for label in entity_list:
        entity_list[label] = sorted(set(entity_list[label]))
    
    entity_summary = "\n\n".join([
        f"{label}:\n" + "\n".join([f"  - {text}" for text in texts[:50]])
        for label, texts in entity_list.items()
    ])
    
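    prompt = f"""Identify which entity mentions refer to the same real-world entity in this antiquities trafficking document. Group them into clusters keyed by a lowercase snake_case canonical ID, using the format shown in the example.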

DOCUMENT START:
{text[:2000]}...

ENTITIES FOUND:
{entity_summary}

Example output format:
{{
  "giacomo_medici": {{
    "full_name": "Giacomo Medici",
    "mentions": ["Giacomo Medici", "Medici"],
    "type": "PERSON",
    "role": "dealer"
  }},
  "j_paul_getty_museum": {{
    "full_name": "J. Paul Getty Museum",
    "mentions": ["J. Paul Getty Museum", "Getty Museum"],
    "type": "ORGANIZATION",
    "role": "museum"
  }}
}}

Return ONLY the JSON object."""
    
    try:
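        response = model.generate_content(
            prompt,
            generation_config=genai.types.GenerationConfig(
                temperature=0.1,
                max_output_tokens=4096,
                # Ask for raw JSON so the reply is not wrapped in markdown fences
                response_mime_type="application/json",
            )
        )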
        
        print(f"  Response length: {len(response.text)} chars")
        
        # Parse JSON
        clusters = json.loads(response.text)
        
        if not isinstance(clusters, dict):
            print(f"  ERROR: Expected dict, got {type(clusters)}")
            return {"refined_entities": [], "entity_clusters": {}, "events": []}
        
        # Convert to expected format
        simplified_clusters = {}
        refined_entities = []
        
        for canonical_id, data in clusters.items():
            if not isinstance(data, dict):
                continue
                
            mentions = data.get('mentions', [])
            if not mentions:
                continue
                
            simplified_clusters[canonical_id] = mentions
            
            for mention in mentions:
                original = next((e for e in entities if e['text'] == mention), None)
                if original:
                    refined_entities.append({
                        "original_text": mention,
                        "canonical_id": canonical_id,
                        "entity_class": data.get('type', 'UNKNOWN'),
                        "attributes": {
                            "full_name": data.get('full_name', ''),
                            "role": data.get('role', '')
                        },
                        "start": original['start'],
                        "end": original['end']
                    })
        
        print(f"  Successfully processed {len(simplified_clusters)} clusters")
        
        return {
            "refined_entities": refined_entities,
            "entity_clusters": simplified_clusters,
            "events": []
        }
        
    except Exception as e:
        print(f"  Error: {e}")
        return {"refined_entities": [], "entity_clusters": {}, "events": []}

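Model replies sometimes arrive wrapped in markdown code fences or with stray text around the object, which makes a bare `json.loads` fail and sends the function into its empty-result fallback. Here is a defensive parsing sketch, assuming the reply contains a single top-level JSON object. `parse_llm_json` is a hypothetical helper, not part of the Gemini SDK; it could stand in for the `json.loads(response.text)` calls in either LLM option.

```python
import json
import re

FENCE = "`" * 3  # a markdown code fence, built indirectly to keep this block well-formed
_FENCE_OPEN = re.compile(r"^\s*" + FENCE + r"[A-Za-z]*\s*")   # opening fence, e.g. with "json"
_FENCE_CLOSE = re.compile(r"\s*" + FENCE + r"\s*$")           # closing fence at the end


def parse_llm_json(reply: str):
    """Parse a JSON object from an LLM reply, tolerating code fences and extra text."""
    cleaned = _FENCE_CLOSE.sub("", _FENCE_OPEN.sub("", reply))
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        # Fall back to the outermost {...} span, e.g. when the model adds a preamble
        start, end = cleaned.find("{"), cleaned.rfind("}")
        if start == -1 or end <= start:
            raise
        return json.loads(cleaned[start:end + 1])
```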
### Option B: Batched (Comprehensive)

In [None]:
def llm_refine_entities_batched(raw_entities, text, batch_size=30):
    """Process entities in batches"""
    
    all_refined = []
    all_clusters = {}
    all_events = []
    
    entities = raw_entities["entities"]
    num_batches = (len(entities) + batch_size - 1) // batch_size
    
    print(f"  Processing {len(entities)} entities in {num_batches} batches...")
    
    for i in range(0, len(entities), batch_size):
        batch = entities[i:i+batch_size]
        batch_num = i // batch_size + 1
        
        print(f"  Batch {batch_num}/{num_batches}...", end=" ")
        
        entity_summary = "\n".join([f"- {e['label']}: '{e['text']}'" for e in batch])
        
        # Get context
        min_start = min(e['start'] for e in batch)
        max_end = max(e['end'] for e in batch)
        context_start = max(0, min_start - 200)
        context_end = min(len(text), max_end + 200)
        context = text[context_start:context_end]
        
        prompt = f"""Analyze these entities from an antiquities trafficking document.

CONTEXT:
{context}

ENTITIES:
{entity_summary}

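Return ONLY valid JSON. In each event, list the participating entities by their canonical IDs from "entity_clusters":
{{
  "refined_entities": [
    {{"original_text": "...", "canonical_id": "...", "entity_class": "PERSON|ORG|ARTIFACT", "attributes": {{}}, "start": 0, "end": 0}}
  ],
  "entity_clusters": {{"canonical_id": ["mention1", "mention2"]}},
  "events": [
    {{"event_id": "...", "event_type": "...", "entities": ["canonical_id_1", "canonical_id_2"]}}
  ]
}}"""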
        
        try:
            response = model.generate_content(
                prompt,
                generation_config=genai.types.GenerationConfig(
                    temperature=0.1,
                    max_output_tokens=8192,
                    response_mime_type="application/json",
                )
            )
            
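            result = json.loads(response.text)
            all_refined.extend(result.get('refined_entities', []))
            # Merge clusters instead of overwriting them: the same canonical ID
            # can appear in several batches with different mentions
            for canonical_id, mentions in result.get('entity_clusters', {}).items():
                merged = all_clusters.setdefault(canonical_id, [])
                for mention in mentions:
                    if mention not in merged:
                        merged.append(mention)
            all_events.extend(result.get('events', []))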
            print("✓")
            
        except Exception as e:
            print(f"Error: {e}")
            continue
    
    return {
        "refined_entities": all_refined,
        "entity_clusters": all_clusters,
        "events": all_events
    }

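In batched mode, each batch chooses canonical IDs independently, so one person can end up under two IDs (say `medici` in one batch and `giacomo_medici` in another). One way to reconcile them, sketched here with a union-find structure (a swapped-in technique, not something the notebook does), is to merge any clusters that share an exact mention string. `merge_overlapping_clusters` is hypothetical. Exact-string overlap is an assumption with a known failure mode: a bare surname shared by two different people would wrongly join their clusters.

```python
def merge_overlapping_clusters(clusters):
    """Merge clusters (canonical_id -> mentions) that share any mention string."""
    parent = {cid: cid for cid in clusters}

    def find(cid):
        # Path halving keeps the union-find trees shallow
        while parent[cid] != cid:
            parent[cid] = parent[parent[cid]]
            cid = parent[cid]
        return cid

    def union(a, b):
        root_a, root_b = find(a), find(b)
        if root_a != root_b:
            parent[root_b] = root_a

    # Link each cluster to the first cluster that used the same mention
    owner = {}
    for cid, mentions in clusters.items():
        for mention in mentions:
            if mention in owner:
                union(owner[mention], cid)
            else:
                owner[mention] = cid

    # Collect mentions under each root ID, keeping first-seen order without duplicates
    merged = {}
    for cid, mentions in clusters.items():
        bucket = merged.setdefault(find(cid), [])
        for mention in mentions:
            if mention not in bucket:
                bucket.append(mention)
    return merged
```

The `canonical_id` fields in `refined_entities` would still point at the pre-merge IDs; remapping them is left out of this sketch.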
### Option C: spaCy Only (No LLM, Fastest)

In [None]:
def create_kg_from_spacy_only(raw_entities, text):
    """Create KG directly from spaCy without LLM refinement"""
    
    entities = raw_entities["entities"]
    clusters = {}
    
    for e in entities:
        if e['label'] == 'PERSON':
            # Use the last word as the canonical ID (crude surname matching:
            # "Giacomo Medici" and "Medici" both map to "medici")
            words = e['text'].split()
            canonical = words[-1].lower().replace('.', '') if words else e['text'].lower()
        elif e['label'] in ['ORG', 'GPE', 'FAC']:
            canonical = e['text'].lower().replace(' ', '_').replace('.', '')
        else:
            continue  # other labels (DATE, MONEY, ...) do not become graph nodes
        
        # Record each distinct surface form once, so repeated mentions
        # do not produce duplicate refined entities below
        mentions = clusters.setdefault(canonical, [])
        if e['text'] not in mentions:
            mentions.append(e['text'])
    
    # Build refined entities
    refined_entities = []
    for canonical_id, mentions in clusters.items():
        for mention in mentions:
            original = next((e for e in entities if e['text'] == mention), None)
            if original:
                refined_entities.append({
                    "original_text": mention,
                    "canonical_id": canonical_id,
                    "entity_class": original['label'],
                    "attributes": {"full_name": mention},
                    "start": original['start'],
                    "end": original['end']
                })
    
    return {
        "refined_entities": refined_entities,
        "entity_clusters": clusters,
        "events": []
    }

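The last-word heuristic can pick the wrong token when a PERSON span ends in a generational suffix ("Robert Hecht Jr." would map to `jr`) or, if a span ever includes one, a possessive ending. Below is a slightly more defensive key function, offered as a sketch only. `person_key` and its suffix list are hypothetical, and the function inherits the heuristic's main weakness: two people who share a surname still collapse into one node.

```python
import re

# Hypothetical list of trailing name suffixes to skip when picking the surname
_SUFFIXES = {"jr", "sr", "ii", "iii", "iv"}


def person_key(mention: str) -> str:
    """Return a surname-based canonical ID for a PERSON mention."""
    # Drop a trailing possessive ('s or ’s)
    text = re.sub(r"['’]s$", "", mention.strip())
    # Lowercase each word and strip surrounding periods and commas
    words = [w.strip(".,").lower() for w in text.split()]
    words = [w for w in words if w]
    # Skip generational suffixes such as "Jr." so they do not become the ID
    while len(words) > 1 and words[-1] in _SUFFIXES:
        words.pop()
    return words[-1] if words else mention.strip().lower()
```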
## 🔗 Stage 3: Knowledge Graph Construction

In [None]:
def build_knowledge_graph(raw_entities, refined_entities, entity_clusters, events):
    """Construct knowledge graph from hybrid extraction"""
    
    def get_position(mention_text, raw_entities):
        for entity in raw_entities:
            if entity['text'] == mention_text:
                return (entity['start'], entity['end'])
        return None
    
    nodes = {}
    edges = []
    
    # Create nodes
    for canonical_id, mentions in entity_clusters.items():
        entity_data = next(
            (e for e in refined_entities if e.get("canonical_id") == canonical_id),
            None
        )
        
        if entity_data is None:
            first_mention = mentions[0] if mentions else canonical_id
            entity_data = {
                "canonical_id": canonical_id,
                "entity_class": "UNKNOWN",
                "attributes": {"full_name": first_mention}
            }
        
        nodes[canonical_id] = {
            "id": canonical_id,
            "type": entity_data.get("entity_class", "UNKNOWN"),
            "label": entity_data.get("attributes", {}).get("full_name", canonical_id),
            "mentions": mentions,
            "attributes": entity_data.get("attributes", {}),
            "source_positions": [
                {"mention": m, "position": get_position(m, raw_entities)}
                for m in mentions
            ]
        }
    
    # Create edges from events
    for event in events:
        event_entities = event.get("entities", [])
        event_id = event.get("event_id", "unknown")
        event_type = event.get("event_type", "related_to")
        
        for i, e1 in enumerate(event_entities):
            for e2 in event_entities[i+1:]:
                edges.append({
                    "source": e1,
                    "target": e2,
                    "relation": event_type,
                    "event_id": event_id
                })
    
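    # If no events, fall back to proximity edges: link each node to the next
    # five nodes in cluster order (first-mention order in spacy_only mode).
    # This is a rough stand-in for co-occurrence, not a measured one.
    if not edges:
        node_ids = list(nodes.keys())
        for i, n1 in enumerate(node_ids):
            for n2 in node_ids[i+1:i+6]: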
                edges.append({
                    "source": n1,
                    "target": n2,
                    "relation": "co_occurs_with",
                    "event_id": "implicit"
                })
    
    return {
        "nodes": nodes,
        "edges": edges,
        "summary": {
            "num_nodes": len(nodes),
            "num_edges": len(edges),
            "node_types": list(set(n["type"] for n in nodes.values()))
        }
    }

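The fallback edges link nodes by their position in the cluster list rather than by whether they actually appear near each other. A sketch of sentence-level co-occurrence instead: two nodes are linked when any of their recorded mentions fall inside the same sentence. It assumes sentence boundaries are supplied as `(start, end)` character spans (for example from spaCy's `doc.sents`, using each sentence's `start_char` and `end_char`) and reuses the `source_positions` field that `build_knowledge_graph` stores on each node. `sentence_cooccurrence_edges` is hypothetical. Because `get_position` records only the first occurrence of each mention string, later repeats of the same string are not counted.

```python
from itertools import combinations


def sentence_cooccurrence_edges(nodes, sentence_spans):
    """Link nodes whose recorded mention positions share a sentence."""
    # Map each sentence index to the set of node IDs mentioned in it
    members = {}
    for node_id, node in nodes.items():
        for item in node.get("source_positions", []):
            position = item.get("position")
            if position is None:
                continue
            start, _end = position
            for idx, (s_start, s_end) in enumerate(sentence_spans):
                if s_start <= start < s_end:
                    members.setdefault(idx, set()).add(node_id)
                    break

    # One edge per node pair, however many sentences they share
    pairs = set()
    for ids in members.values():
        pairs.update(combinations(sorted(ids), 2))
    return [
        {"source": a, "target": b, "relation": "co_occurs_with", "event_id": "sentence"}
        for a, b in sorted(pairs)
    ]
```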
## 📊 Visualization Functions

In [None]:
def display_kg(kg, max_nodes=20):
    """Pretty print the knowledge graph"""
    
    print("\n" + "="*60)
    print("KNOWLEDGE GRAPH")
    print("="*60)
    
    print(f"\n📊 SUMMARY:")
    print(f"   Nodes: {kg['summary']['num_nodes']}")
    print(f"   Edges: {kg['summary']['num_edges']}")
    print(f"   Types: {', '.join(kg['summary']['node_types'])}")
    
    print(f"\n👤 NODES (showing first {max_nodes}):")
    for i, (node_id, node_data) in enumerate(list(kg['nodes'].items())[:max_nodes]):
        print(f"\n   {i+1}. {node_data['label']} ({node_data['type']})")
        print(f"      ID: {node_id}")
        print(f"      Mentions: {', '.join(node_data['mentions'][:5])}")
        if node_data['attributes']:
            print(f"      Attributes: {node_data['attributes']}")
    
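    print(f"\n🔗 EDGES (showing first 20):")
    for i, edge in enumerate(kg['edges'][:20]):
        # LLM-generated events can name IDs that are not nodes; fall back to the raw ID
        source_label = kg['nodes'].get(edge['source'], {}).get('label', edge['source'])
        target_label = kg['nodes'].get(edge['target'], {}).get('label', edge['target'])
        print(f"   {i+1}. {source_label} --[{edge['relation']}]--> {target_label}")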
    
    if len(kg['edges']) > 20:
        print(f"   ... and {len(kg['edges']) - 20} more edges")


def visualize_kg_simple(kg, output_file="kg_viz.png"):
    """Create simple matplotlib visualization"""
    
    G = nx.Graph()
    
    # Add nodes
    for node_id, node_data in kg['nodes'].items():
        G.add_node(node_id, label=node_data['label'], type=node_data['type'])
    
    # Add edges
    for edge in kg['edges']:
        if edge['source'] in G.nodes and edge['target'] in G.nodes:
            G.add_edge(edge['source'], edge['target'], relation=edge['relation'])
    
    # Create layout
    plt.figure(figsize=(16, 12))
    pos = nx.spring_layout(G, k=0.5, iterations=50, seed=42)  # fixed seed for a reproducible layout
    
    # Color nodes by type
    color_map = {
        "PERSON": "#FF6B6B",
        "ORG": "#4ECDC4", 
        "ORGANIZATION": "#4ECDC4",
        "ARTIFACT": "#FFE66D",
        "GPE": "#95E1D3",
        "LOCATION": "#95E1D3",
        "UNKNOWN": "#CCCCCC"
    }
    
    node_colors = [color_map.get(kg['nodes'][node]['type'], "#CCCCCC") for node in G.nodes()]
    
    # Draw
    nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=1000, alpha=0.9)
    nx.draw_networkx_edges(G, pos, alpha=0.3, width=2)
    
    # Labels
    labels = {node: kg['nodes'][node]['label'][:20] for node in G.nodes()}
    nx.draw_networkx_labels(G, pos, labels, font_size=8)
    
    plt.title(f"Knowledge Graph: {kg['summary']['num_nodes']} nodes, {kg['summary']['num_edges']} edges", 
              fontsize=14, fontweight='bold')
    plt.axis('off')
    plt.tight_layout()
    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    print(f"\n✅ Visualization saved to {output_file}")
    plt.show()


def save_kg_json(kg, filename="knowledge_graph.json"):
    """Save knowledge graph as JSON"""
    
    kg_serializable = {
        "nodes": [{"id": node_id, **node_data} for node_id, node_data in kg['nodes'].items()],
        "edges": kg['edges'],
        "summary": kg['summary']
    }
    
    with open(filename, 'w', encoding='utf-8') as f:
        json.dump(kg_serializable, f, indent=2, ensure_ascii=False)
    
    print(f"\n✅ Knowledge graph saved to {filename}")
    return filename

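`save_kg_json` writes nodes as a list, while `display_kg`, `visualize_kg_simple`, and `query_kg` expect `kg['nodes']` to be a dict keyed by node ID. Here is a small loader sketch that reverses the conversion so a saved graph can be reopened. `kg_from_saved` and `load_kg_json` are hypothetical names. JSON has no tuple type, so `position` values come back as lists.

```python
import json


def kg_from_saved(data):
    """Convert the save_kg_json layout (list of nodes) back to {node_id: node}."""
    nodes = {node["id"]: node for node in data["nodes"]}
    return {"nodes": nodes, "edges": data["edges"], "summary": data["summary"]}


def load_kg_json(filename="knowledge_graph.json"):
    """Reload a graph written by save_kg_json."""
    with open(filename, encoding="utf-8") as f:
        return kg_from_saved(json.load(f))
```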
## 🚀 Main Pipeline

In [None]:
def hybrid_kg_extraction(document_text, mode="spacy_only"):
    """Full hybrid extraction pipeline.
    
    Args:
        document_text: Raw text of the document to process.
        mode: "spacy_only" (fastest, no LLM), "simple" (one LLM call for
            coreference), or "batched" (one LLM call per 30 entities).
    """
    
    print("Stage 1: spaCy extraction...")
    raw = extract_entities_fast(document_text)
    print(f"  Found {len(raw['entities'])} raw entities")
    
    print(f"\nStage 2: refinement (mode={mode})...")
    
    if mode == "spacy_only":
        refined = create_kg_from_spacy_only(raw, document_text)
    elif mode == "simple":
        refined = llm_refine_entities_simple(raw, document_text)
    elif mode == "batched":
        refined = llm_refine_entities_batched(raw, document_text, batch_size=30)
    else:
        raise ValueError(f"Unknown mode: {mode}")
    
    print(f"  Refined to {len(refined['refined_entities'])} entities")
    print(f"  Identified {len(refined['entity_clusters'])} entity clusters")
    print(f"  Found {len(refined['events'])} events")
    
    return {
        "raw": raw,
        "refined": refined
    }

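The timings quoted for each mode depend on hardware, document length, and API latency. To check them on your own documents, here is a sketch that times any stage function with `time.perf_counter`. `timed` is a hypothetical helper.

```python
import time


def timed(label, fn, *args, **kwargs):
    """Run fn(*args, **kwargs), print the elapsed wall-clock time, and return its result."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    elapsed = time.perf_counter() - start
    print(f"{label}: {elapsed:.2f} s")
    return result
```

Once a document is loaded, `timed("spacy_only", hybrid_kg_extraction, document_text, mode="spacy_only")` would run the pipeline and report how long it took.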
## 📝 Example Usage

In [None]:
# Sample document text
document_text = """Giacomo Medici is an Italian antiquities dealer who was convicted in 2005 of receiving stolen goods, illegal export of goods, and conspiracy to traffic.

Medici started dealing in antiquities in Rome during the 1960s. In July 1967, he was convicted in Italy of receiving looted artefacts, though in the same year he met and became an important supplier of antiquities to US dealer Robert Hecht. In 1968, Medici opened the gallery Antiquaria Romana in Rome and began to explore business opportunities in Switzerland.

In 1978, he closed his Rome gallery, and entered into partnership with Geneva resident Christian Boursaud, who started consigning material supplied by Medici for sale at Sotheby's London. Together, they opened Hydra Gallery in Geneva in 1983.

In October 1985, the Hydra Gallery sold fragments of the Onesimos kylix to the J. Paul Getty Museum for $100,000, providing a false provenance by way of the fictitious Zbinden collection. The Getty returned the kylix to Italy in 1999.

On 13 September 1995, in concert with Swiss police, they raided Medici's storage space in the Geneva Freeport. In January 1997, Medici was arrested in Rome.

Medici was charged with receiving stolen goods, illegal export of goods, and conspiracy to traffic. On 12 May 2005, he was found guilty of all charges. He was sentenced to ten years in prison and received a €10 million fine."""

# Run the pipeline
result = hybrid_kg_extraction(document_text, mode="spacy_only")

## 🏗️ Build Knowledge Graph

In [None]:
# Build the knowledge graph
kg = build_knowledge_graph(
    result['raw']['entities'],
    result['refined']['refined_entities'],
    result['refined']['entity_clusters'],
    result['refined']['events']
)

# Display summary
display_kg(kg, max_nodes=15)

## 📊 Visualize and Save

In [None]:
# Visualize
visualize_kg_simple(kg, "antiquities_kg.png")

# Save as JSON
save_kg_json(kg, "antiquities_kg.json")

## 🔍 Query the Graph

In [None]:
def query_kg(kg, query_type, **kwargs):
    """Simple query interface"""
    
    if query_type == "find_person":
        # Match against both the canonical ID and the display label
        name = kwargs.get('name', '').lower()
        return [(nid, n) for nid, n in kg['nodes'].items()
                if n['type'] == 'PERSON' and (name in nid.lower() or name in n['label'].lower())]
    
    elif query_type == "connections":
        node_id = kwargs.get('node_id')
        connections = [(e['target'], e['relation']) for e in kg['edges']
                      if e['source'] == node_id]
        connections.extend([(e['source'], e['relation']) for e in kg['edges']
                           if e['target'] == node_id])
        return connections
    
    elif query_type == "by_type":
        entity_type = kwargs.get('type')
        return [(nid, n) for nid, n in kg['nodes'].items()
                if n['type'] == entity_type]
    
    else:
        raise ValueError(f"Unknown query type: {query_type}")

# Example queries
print("\n🔍 QUERIES")
print("\nAll people:")
people = query_kg(kg, "by_type", type="PERSON")
for pid, person in people[:10]:
    print(f"  - {person['label']}")

print("\nAll organizations:")
orgs = query_kg(kg, "by_type", type="ORG")
for oid, org in orgs[:10]:
    print(f"  - {org['label']}")

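The `connections` query only looks one hop away. To ask how two entities are linked through intermediaries (for example, how a dealer connects to a museum), a breadth-first search over `kg['edges']` returns a shortest chain of node IDs. This is a sketch: `shortest_path` is hypothetical and treats every edge as undirected, matching how `connections` reads both directions.

```python
from collections import deque


def shortest_path(kg, start, goal):
    """Breadth-first search for a shortest chain of node IDs from start to goal."""
    # Build an undirected adjacency list from the edge list
    neighbors = {}
    for edge in kg["edges"]:
        neighbors.setdefault(edge["source"], set()).add(edge["target"])
        neighbors.setdefault(edge["target"], set()).add(edge["source"])

    previous = {start: None}
    queue = deque([start])
    while queue:
        node = queue.popleft()
        if node == goal:
            # Walk the predecessor links back to the start
            path = []
            while node is not None:
                path.append(node)
                node = previous[node]
            return path[::-1]
        for nxt in sorted(neighbors.get(node, ())):
            if nxt not in previous:
                previous[nxt] = node
                queue.append(nxt)
    return None  # goal not reachable from start
```

With the proximity fallback edges from `build_knowledge_graph`, a path reflects cluster order rather than documented relationships, so it is most meaningful with event-based edges.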
## 📤 Use Your Own Document

In [None]:
# Load your document
# Option 1: From file
# with open('your_document.txt', 'r', encoding='utf-8') as f:
#     your_document = f.read()

# Option 2: Paste directly
your_document = """Paste your antiquities trafficking document here..."""

# Run extraction
# Try "spacy_only" first (fastest), then "simple" or "batched" if you need better coreference
result = hybrid_kg_extraction(your_document, mode="spacy_only")

# Build graph
kg = build_knowledge_graph(
    result['raw']['entities'],
    result['refined']['refined_entities'],
    result['refined']['entity_clusters'],
    result['refined']['events']
)

# Display and save
display_kg(kg)
visualize_kg_simple(kg, "my_kg.png")
save_kg_json(kg, "my_kg.json")

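Running the pipeline once per document yields one graph per document. Here is a sketch of combining them, assuming canonical IDs are comparable across documents: this holds for the spaCy-only string keys but is not guaranteed for LLM-chosen IDs. `merge_kgs` is hypothetical. It keeps each node's type, label, and attributes from the first graph that contains it, unions mentions, tags every source position with its document index (character offsets only make sense within one document), drops duplicate edges, and recomputes the summary.

```python
def merge_kgs(kgs):
    """Combine several knowledge graphs that share canonical node IDs."""
    nodes, edges, seen_edges = {}, [], set()
    for doc_index, kg in enumerate(kgs):
        for node_id, node in kg["nodes"].items():
            # First occurrence supplies type/label/attributes; lists start empty
            merged = nodes.setdefault(node_id, {**node, "mentions": [], "source_positions": []})
            for mention in node["mentions"]:
                if mention not in merged["mentions"]:
                    merged["mentions"].append(mention)
            # Offsets are per document, so tag each position with its document index
            merged["source_positions"].extend(
                {**pos, "doc": doc_index} for pos in node.get("source_positions", [])
            )
        for edge in kg["edges"]:
            key = (edge["source"], edge["target"], edge["relation"])
            if key not in seen_edges:
                seen_edges.add(key)
                edges.append(edge)
    return {
        "nodes": nodes,
        "edges": edges,
        "summary": {
            "num_nodes": len(nodes),
            "num_edges": len(edges),
            "node_types": sorted({n["type"] for n in nodes.values()}),
        },
    }
```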
## 📋 Summary

### Three Modes Available:

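1. **`spacy_only`** (Recommended Start)
   - Fastest (roughly 1-2 seconds for a short document)
   - No LLM calls
   - Simple last-name-based coreference for people; exact-string matching for organizations and places
   - Good for quick exploration

2. **`simple`** (Better Coreference)
   - Medium speed (roughly 5-10 seconds)
   - 1 LLM call for coreference resolution over PERSON, ORG, and GPE mentions
   - Sends only the first 2,000 characters of the document as context and at most 50 mentions per type
   - Better entity linking; good for production on short documents

3. **`batched`** (Full Feature)
   - Slower (roughly 30-60 seconds)
   - One LLM call per batch of 30 entities
   - Domain classification (PERSON/ORG/ARTIFACT) plus event extraction
   - Best quality, although each batch picks canonical IDs independently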

### Next Steps:
- Scale to multiple documents
- Fine-tune spaCy NER model on antiquities domain
- Export to Neo4j for graph database queries
- Add relationship extraction rules
- Create interactive web dashboard