
# Temporal Agent for Financial Transcripts

**Author:** Simran Bains

---

## Overview
This project develops a **time-aware RAG system** for financial earnings call transcripts, where earlier statements are marked as expired once newer updates appear. It combines semantic chunking, fact extraction with timestamps, entity resolution, temporal invalidation, and knowledge graph construction. The result is a dynamic knowledge base that can answer **time-sensitive questions** about how a company’s outlook changes over time.

**Code repository:** https://github.com/simranbains9810/temporal_agent

---

## Table of Contents
1. [Notebook Setup](#notebook-setup)  
2. [Pre-processing and Analyzing our Dynamic Data](#pre-processing-and-analyzing-our-dynamic-data)  
3. [Percentile Based Chunking](#percentile-based-chunking)  
4. [Extracting Atomic Facts with a Statement Agent](#extracting-atomic-facts-with-a-statement-agent)  
5. [Pinpointing Time with a Validation Check Agent](#pinpointing-time-with-a-validation-check-agent)  
6. [Structuring Facts into Triplets](#structuring-facts-into-triplets)
7. [Assembling the Temporal Event](#assembling-the-temporal-event)
8. [Automating the Pipeline with LangGraph](#automating-the-pipeline-with-langgraph)
9. [Cleaning Our Data with Entity Resolution](#cleaning-our-data-with-entity-resolution)
10. [Making Our Knowledge Dynamic with an Invalidation Agent](#making-our-knowledge-dynamic-with-an-invalidation-agent)
11. [Assembling the Temporal Knowledge Graph](#assembling-the-temporal-knowledge-graph)
12. [Building and Testing A Multi-Step Retrieval Agent](#building-and-testing-a-multi-step-retrieval-agent)

---

## Notebook Setup
- Core: `langchain-community`, `datasets==2.19.0`, `langchain-experimental`, `langgraph`
- Embeddings/LLM: `langchain-nebius`, `openai` (indirect), `tiktoken`
- Data/Utils: `pandas`, `numpy`, `pyarrow`, `python-dateutil`, `rapidfuzz`, `networkx`, `matplotlib`, `scipy`
- (Colab will already include many of these.)

### Install/Upgrade Dependencies
```bash
# Base installs (pin datasets to 2.19.0; langchain-* at current project versions)
pip install -q "datasets==2.19.0" langchain-community langchain-experimental langgraph \
               langchain-nebius rapidfuzz networkx matplotlib scipy

# Optional: if you hit storage/FS conflicts, align fsspec with datasets 2.19.0 constraints
pip install -q "fsspec==2024.3.1" || true
```

## Pre-processing and Analyzing our Dynamic Data

In the pre-processing stage we use Hugging Face to load the `jlh-ibm/earnings_call` dataset, which gives us 188 transcripts from ten large tech companies. Each transcript comes with a company name and date, so we already have a natural time anchor. The transcripts are hefty, at around 8,800 words on average. The first step is to check how the data is distributed across companies with some basic statistics, and to pull out quarter references such as "Q2 2016" from the text so we can identify which period each transcript describes.

In [None]:
pip install langchain-community "datasets==2.19.0"


In [None]:
# Import loader for Hugging Face datasets
from langchain_community.document_loaders import HuggingFaceDatasetLoader

# Dataset configuration
hf_dataset_name = "jlh-ibm/earnings_call"  # HF dataset name
subset_name = "transcripts"                # Dataset subset to load

# Create the loader (it yields rows from every split in the subset)
loader = HuggingFaceDatasetLoader(
    path=hf_dataset_name,
    name=subset_name,
    page_content_column="transcript"  # Column containing the main text
)

# This is the key step. The loader processes the dataset and returns a list of LangChain Document objects.
documents = loader.load()



In [None]:
# Check how many documents were loaded
print(f"Loaded {len(documents)} documents.")

Loaded 188 documents.


In [None]:
# Count how many documents each company has
company_counts = {}

# Loop over all loaded documents
for doc in documents:
    company = doc.metadata.get("company")  # Extract company from metadata
    if company:
        company_counts[company] = company_counts.get(company, 0) + 1

# Display the counts
print("Total company counts:")
for company, count in company_counts.items():
    print(f" - {company}: {count}")

Total company counts:
 - AMD: 19
 - AAPL: 19
 - INTC: 19
 - MU: 17
 - GOOGL: 19
 - ASML: 19
 - CSCO: 19
 - NVDA: 19
 - AMZN: 19
 - MSFT: 19


In [None]:
# Print metadata for two sample documents (index 0 and 33)
print("Metadata for document[0]:")
print(documents[0].metadata)

print("\nMetadata for document[33]:")
print(documents[33].metadata)

Metadata for document[0]:
{'company': 'AMD', 'date': datetime.date(2016, 7, 21)}

Metadata for document[33]:
{'company': 'AMZN', 'date': datetime.date(2019, 10, 24)}


In [None]:
# Print the first 200 characters of the first document's content
first_doc = documents[0]
print(first_doc.page_content[:200])



In [None]:
# Calculate the average number of words per document
total_words = sum(len(doc.page_content.split()) for doc in documents)
average_words = total_words / len(documents) if documents else 0

print(f"Average number of words in documents: {average_words:.2f}")

Average number of words in documents: 8797.19


In [None]:
import re
from datetime import datetime

# Helper function to extract a quarter string (e.g., "Q1 2023") from text
def find_quarter(text: str) -> str | None:
    """Return the first quarter-year match found in the text, or None if absent."""
    # Match pattern: 'Q' followed by 1 digit, a space, and a 4-digit year
    match = re.findall(r"Q\d\s\d{4}", text)
    return match[0] if match else None

# Test on the first document
quarter = find_quarter(documents[0].page_content)
print(f"Extracted Quarter for the first document: {quarter}")

Extracted Quarter for the first document: Q2 2016


## Percentile Based Chunking

This section explains how percentile-based semantic chunking splits long financial transcripts into smaller, meaningful segments without losing context. Instead of cutting at fixed lengths, each sentence is embedded with a model (Qwen3-Embedding-8B via Nebius), the semantic distance between consecutive sentences is calculated, and chunk boundaries are placed where that distance is unusually large. The process works as follows:

*   Split text into sentences via regex.
*   Embed each sentence into a vector.
*   Compute semantic distance between consecutive vectors.
*   Compute the chosen percentile (e.g., the 95th) of those distances; this is the breakpoint threshold.
*   Mark distances above the threshold as breakpoints.
*   Group sentences between breakpoints into chunks, enforcing minimum size and optional overlap.

Using this method, the 188 long transcripts (≈8,800 words each) are transformed into 4,275 smaller chunks (≈387 words each), and each chunk retains metadata such as company, date and quarter for later retrieval.
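The steps above can be sketched without any embedding service. This is a minimal, pure-Python illustration of the percentile rule, not LangChain's `SemanticChunker` internals: it assumes cosine distance, NumPy-style linear-interpolated percentiles and a strict `>` comparison, it omits minimum-size and overlap handling, and it feeds in hypothetical 2-D vectors in place of Qwen3 embeddings. The helper names (`split_sentences`, `cosine_distance`, `percentile`, `percentile_chunks`) are illustrative only.

```python
import math
import re

def split_sentences(text: str) -> list[str]:
    # Split after ".", "?" or "!" when followed by whitespace (a simple regex splitter)
    return [s for s in re.split(r"(?<=[.?!])\s+", text.strip()) if s]

def cosine_distance(a: list[float], b: list[float]) -> float:
    # 1 - cosine similarity: near 0 for similar sentences, larger for unrelated ones
    dot = sum(x * y for x, y in zip(a, b))
    return 1.0 - dot / (math.hypot(*a) * math.hypot(*b))

def percentile(values: list[float], pct: float) -> float:
    # Linear interpolation between closest ranks (NumPy's default method)
    ordered = sorted(values)
    pos = (len(ordered) - 1) * pct / 100.0
    lo = math.floor(pos)
    hi = min(lo + 1, len(ordered) - 1)
    return ordered[lo] + (ordered[hi] - ordered[lo]) * (pos - lo)

def percentile_chunks(sentences: list[str], vectors: list[list[float]], pct: float = 95.0) -> list[str]:
    if len(sentences) < 2:
        return [" ".join(sentences)]
    # Distance between each sentence and the one that follows it
    distances = [cosine_distance(vectors[i], vectors[i + 1]) for i in range(len(vectors) - 1)]
    threshold = percentile(distances, pct)
    chunks, start = [], 0
    for i, dist in enumerate(distances):
        if dist > threshold:  # a large semantic jump closes the current chunk
            chunks.append(" ".join(sentences[start:i + 1]))
            start = i + 1
    chunks.append(" ".join(sentences[start:]))
    return chunks

text = "Revenue grew. Margins held. Guidance is steady. Zen ships in 2017. Servers follow."
sentences = split_sentences(text)
# Hypothetical 2-D vectors standing in for real sentence embeddings
vectors = [[1.0, 0.0], [1.0, 0.1], [1.0, 0.05], [0.0, 1.0], [0.1, 1.0]]
print(percentile_chunks(sentences, vectors))
# ['Revenue grew. Margins held. Guidance is steady.', 'Zen ships in 2017. Servers follow.']
```

Because the threshold is relative to each document's own distance distribution, a 95th-percentile setting places breakpoints only at the largest topic shifts, whatever the absolute scale of the embedding distances.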

In [None]:
pip install langchain-nebius



In [None]:
import os

In [None]:
from langchain_nebius import NebiusEmbeddings

# Set Nebius API key (⚠️ Avoid hardcoding secrets in production code)
os.environ["NEBIUS_API_KEY"] = "INSERT NEBIUS KEY"

# 1. Initialize Nebius embedding model
embeddings = NebiusEmbeddings(model="Qwen/Qwen3-Embedding-8B")

In [None]:
pip install langchain-experimental



In [None]:
from langchain_experimental.text_splitter import SemanticChunker

# Create a semantic chunker using percentile thresholding
langchain_semantic_chunker = SemanticChunker(
    embeddings,
    breakpoint_threshold_type="percentile",  # Use percentile-based splitting
    breakpoint_threshold_amount=95           # split at 95th percentile
)

In [None]:
from tqdm import tqdm

In [None]:
# Store the new, smaller chunk documents
chunked_documents_lc = []

# Print the total number of documents to process (188, as loaded above)
print(f"Processing {len(documents)} documents using LangChain's SemanticChunker...")

# Chunk each transcript document
for doc in tqdm(documents, desc="Chunking Transcripts with LangChain"):
    # Extract quarter info and copy existing metadata
    quarter = find_quarter(doc.page_content)
    parent_metadata = doc.metadata.copy()
    parent_metadata["quarter"] = quarter

    # Perform semantic chunking (returns Document objects with metadata attached)
    chunks = langchain_semantic_chunker.create_documents(
        [doc.page_content],
        metadatas=[parent_metadata]
    )

    # Collect all chunks
    chunked_documents_lc.extend(chunks)

Processing 188 documents using LangChain's SemanticChunker...


Chunking Transcripts with LangChain: 100%|██████████| 188/188 [10:09:04<00:00, 194.39s/it]


In [None]:
# Analyze the results of the LangChain chunking process
original_doc_count = len(documents)
chunked_doc_count = len(chunked_documents_lc)

print(f"Original number of documents (transcripts): {original_doc_count}")
print(f"Number of new documents (chunks): {chunked_doc_count}")
print(f"Average chunks per transcript: {chunked_doc_count / original_doc_count:.2f}")


Original number of documents (transcripts): 188
Number of new documents (chunks): 4275
Average chunks per transcript: 22.74


In [None]:
# Inspect the 11th chunk (index 10)
sample_chunk = chunked_documents_lc[10]
print("Sample Chunk Content (first 30 chars):")
print(sample_chunk.page_content[:30] + "...")

print("\nSample Chunk Metadata:")
print(sample_chunk.metadata)

# Calculate average word count per chunk
total_chunk_words = sum(len(doc.page_content.split()) for doc in chunked_documents_lc)
average_chunk_words = total_chunk_words / chunked_doc_count if chunked_documents_lc else 0

print(f"\nAverage number of words per chunk: {average_chunk_words:.2f}")

Sample Chunk Content (first 30 chars):
No, that's a fair question, Ma...

Sample Chunk Metadata:
{'company': 'AMD', 'date': datetime.date(2016, 7, 21), 'quarter': 'Q2 2016'}

Average number of words per chunk: 386.87


## Extracting Atomic Facts with a Statement Agent

Now that the transcripts are split into semantic chunks, we distil each chunk into atomic statements: self-contained claims that can stand on their own. Each statement is labelled in two ways: (1) by type of claim (fact, opinion or prediction) and (2) by temporal nature: static (anchored to a single point in time), dynamic (a state that can change over time) or atemporal (always true). To keep the model's output consistently structured, we define a schema with Pydantic models that acts like a contract: the LLM must return a JSON object containing each statement plus its labels. By tagging statements this way, the AI agent can separate what actually happened from what management believes or expects to happen.

In [None]:
from enum import Enum

# Enum for temporal labels describing time sensitivity
class TemporalType(str, Enum):
    ATEMPORAL = "ATEMPORAL"  # Facts that are always true (e.g., "Earth is a planet")
    STATIC = "STATIC"        # Facts about a single point in time (e.g., "Product X launched on Jan 1st")
    DYNAMIC = "DYNAMIC"      # Facts describing an ongoing state (e.g., "Lisa Su is the CEO")

In [None]:
# Enum for statement labels classifying statement nature
class StatementType(str, Enum):
    FACT = "FACT"            # An objective, verifiable claim
    OPINION = "OPINION"      # A subjective belief or judgment
    PREDICTION = "PREDICTION"  # A statement about a future event

In [None]:
from pydantic import BaseModel, field_validator

# This model defines the structure for a single extracted statement
class RawStatement(BaseModel):
    statement: str
    statement_type: StatementType
    temporal_type: TemporalType

# This model is a container for the list of statements from one chunk
class RawStatementList(BaseModel):
    statements: list[RawStatement]
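To see the "contract" in action, the sketch below validates hand-written JSON against the same schema (restated so the cell runs on its own). Both payloads are hypothetical examples, not real model output; `model_validate_json` and `ValidationError` are standard Pydantic v2 APIs.

```python
from enum import Enum
from pydantic import BaseModel, ValidationError

class StatementType(str, Enum):
    FACT = "FACT"
    OPINION = "OPINION"
    PREDICTION = "PREDICTION"

class TemporalType(str, Enum):
    ATEMPORAL = "ATEMPORAL"
    STATIC = "STATIC"
    DYNAMIC = "DYNAMIC"

class RawStatement(BaseModel):
    statement: str
    statement_type: StatementType
    temporal_type: TemporalType

class RawStatementList(BaseModel):
    statements: list[RawStatement]

# Hypothetical, well-formed payload of the kind the LLM is asked to return
good_json = (
    '{"statements": [{"statement": "AMD expects notebook Zen products in 2017.",'
    ' "statement_type": "PREDICTION", "temporal_type": "STATIC"}]}'
)
parsed = RawStatementList.model_validate_json(good_json)
print(parsed.statements[0].statement_type.value)  # PREDICTION

# A payload with a label outside the schema is rejected instead of silently accepted
bad_json = (
    '{"statements": [{"statement": "Zen is exciting.",'
    ' "statement_type": "RUMOUR", "temporal_type": "STATIC"}]}'
)
try:
    RawStatementList.model_validate_json(bad_json)
    rejected = False
except ValidationError:
    rejected = True
print(rejected)  # True
```

The string labels are coerced into enum members, so downstream code can compare against `StatementType.PREDICTION` rather than raw strings.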

In [None]:
# These definitions provide the necessary context for the LLM to understand the labels.
LABEL_DEFINITIONS: dict[str, dict[str, dict[str, str]]] = {
    "episode_labelling": {
        "FACT": dict(definition="Statements that are objective and can be independently verified or falsified through evidence."),
        "OPINION": dict(definition="Statements that contain personal opinions, feelings, values, or judgments that are not independently verifiable."),
        "PREDICTION": dict(definition="Uncertain statements about the future on something that might happen, a hypothetical outcome, unverified claims."),
    },
    "temporal_labelling": {
        "STATIC": dict(definition="Often past tense, think -ed verbs, describing single points-in-time."),
        "DYNAMIC": dict(definition="Often present tense, think -ing verbs, describing a period of time."),
        "ATEMPORAL": dict(definition="Statements that will always hold true regardless of time."),
    },
}

In [None]:
# Format label definitions into a clean string for prompt injection
definitions_text = ""

for section_key, section_dict in LABEL_DEFINITIONS.items():
    # Add a section header with underscores replaced by spaces and uppercased
    definitions_text += f"==== {section_key.replace('_', ' ').upper()} DEFINITIONS ====\n"

    # Add each category and its definition under the section
    for category, details in section_dict.items():
        definitions_text += f"- {category}: {details.get('definition', '')}\n"

In [None]:
from langchain_core.prompts import ChatPromptTemplate

# Define the prompt template for statement extraction and labeling
statement_extraction_prompt_template = """
You are an expert extracting atomic statements from text.

Inputs:
- main_entity: {main_entity}
- publication_date: {publication_date}
- document_chunk: {document_chunk}

Label definitions:
{definitions}

Tasks:
1. Extract clear, single-subject statements.
2. Label each as FACT, OPINION, or PREDICTION.
3. Label each temporally as STATIC, DYNAMIC, or ATEMPORAL.
4. Resolve references to main_entity and include dates/quantities.

Return ONLY a JSON object with the statements and labels.
"""

# Create a ChatPromptTemplate from the string template
prompt = ChatPromptTemplate.from_template(statement_extraction_prompt_template)


In [None]:
from langchain_nebius import ChatNebius
import json

# Initialize our LLM
llm = ChatNebius(model="deepseek-ai/DeepSeek-V3")

# Create the chain: prompt -> LLM -> structured output parser
statement_extraction_chain = prompt | llm.with_structured_output(RawStatementList)

In [None]:
# Select the sample chunk we inspected earlier for testing extraction
sample_chunk_for_extraction = chunked_documents_lc[10]

print("--- Running statement extraction on a sample chunk ---")
print(f"Chunk Content:\n{sample_chunk_for_extraction.page_content}")
print("\nInvoking LLM for extraction...")

# Call the extraction chain with necessary inputs
extracted_statements_list = statement_extraction_chain.invoke({
    "main_entity": sample_chunk_for_extraction.metadata["company"],
    "publication_date": sample_chunk_for_extraction.metadata["date"].isoformat(),
    "document_chunk": sample_chunk_for_extraction.page_content,
    "definitions": definitions_text
})

print("\n--- Extraction Result ---")
# Pretty-print the output JSON from the model response
print(extracted_statements_list.model_dump_json(indent=2))

--- Running statement extraction on a sample chunk ---
Chunk Content:
No, that's a fair question, Matt. So we have been very focused on the server launch for first half of 2017. Desktop should launch before that. In terms of true volume availability, I believe it will be in the first quarter of 2017. We may ship some limited volume towards the end of the fourth quarter, based on how bring-up goes and the customer readiness.\nBut again, if I look overall at what we are trying to do, I think the desktop product is very well-positioned for that high-end desktop segment, that enthusiast segment, in both channel and OEM, which is very much a segment that AMD knows well. And so that's where we would focus -- on desktop.\nYou should expect a notebook version of Zen with integrated graphics in 2017, and that development is going on as well. And so I think it's just a time of a lot of activity around the Zen and the different Zen product families.\n\n--------------------------------------------

## Pinpointing Time with a Validation Check Agent

The next stage attaches precise timestamps to each fact so that the knowledge base becomes temporal. For example, the statement "AMD has been very focused on the server launch for the first half of 2017" has no machine-readable time anchor. To solve this we build a specialised date agent which (1) uses the document's publication date as the reference point, (2) interprets fuzzy phrases such as "next quarter" or "three months ago", and (3) converts them into an explicit validity range: here `valid_at` = 2017-01-01, when the statement became true, and `invalid_at` = 2017-06-30, when it stopped being true (if it has).
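The notebook delegates this resolution to an LLM. To make the target output concrete, here is a small rule-based sketch, assuming calendar quarters, of what turning "the first half of 2017" and "next quarter" into explicit UTC ranges looks like. The helpers `quarter_bounds`, `half_bounds` and `next_quarter` are hypothetical illustrations, not part of the pipeline; several companies in the dataset report on fiscal calendars that do not line up with calendar quarters, which is one reason fixed rules like these are not enough on their own.

```python
from datetime import date, datetime, timedelta, timezone

def quarter_bounds(year: int, quarter: int) -> tuple[datetime, datetime]:
    """First and last second (UTC) of a calendar quarter."""
    start_month = 3 * (quarter - 1) + 1
    start = datetime(year, start_month, 1, tzinfo=timezone.utc)
    # The quarter ends one second before the next quarter starts
    if quarter == 4:
        next_start = datetime(year + 1, 1, 1, tzinfo=timezone.utc)
    else:
        next_start = datetime(year, start_month + 3, 1, tzinfo=timezone.utc)
    return start, next_start - timedelta(seconds=1)

def half_bounds(year: int, half: int) -> tuple[datetime, datetime]:
    """First half spans Q1..Q2, second half spans Q3..Q4."""
    first_q = 1 if half == 1 else 3
    return quarter_bounds(year, first_q)[0], quarter_bounds(year, first_q + 1)[1]

def next_quarter(publication: date) -> tuple[int, int]:
    """Resolve "next quarter" relative to the publication date, as (year, quarter)."""
    current_q = (publication.month - 1) // 3 + 1
    return (publication.year + 1, 1) if current_q == 4 else (publication.year, current_q + 1)

# "the first half of 2017" from the example statement
valid_at, invalid_at = half_bounds(2017, 1)
print(valid_at.isoformat(), invalid_at.isoformat())
# 2017-01-01T00:00:00+00:00 2017-06-30T23:59:59+00:00

# "next quarter", said on a call published 2016-07-21
year, q = next_quarter(date(2016, 7, 21))
print(year, q, *(d.date().isoformat() for d in quarter_bounds(year, q)))
# 2016 4 2016-10-01 2016-12-31
```

Ending a range one second before the next period starts avoids hard-coding month lengths, which is how invalid dates such as "June 31" creep in.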

In [None]:
from datetime import datetime, timezone
from dateutil.parser import parse
import re

def parse_date_str(value: str | datetime | None) -> datetime | None:
    """
    Parse a string or datetime into a timezone-aware datetime object (UTC).
    Returns None if parsing fails or input is None.
    """
    if not value:
        return None

    # If already a datetime, ensure it has timezone info (UTC if missing)
    if isinstance(value, datetime):
        return value if value.tzinfo else value.replace(tzinfo=timezone.utc)

    try:
        # Handle year-only strings like "2023"
        if re.fullmatch(r"\d{4}", value.strip()):
            year = int(value.strip())
            return datetime(year, 1, 1, tzinfo=timezone.utc)

        # Parse more complex date strings with dateutil
        dt: datetime = parse(value)

        # Ensure timezone-aware, default to UTC if missing
        if dt.tzinfo is None:
            dt = dt.replace(tzinfo=timezone.utc)

        return dt
    except Exception:
        return None

In [None]:
from pydantic import BaseModel, Field, field_validator
from datetime import datetime

# Model for raw temporal range with date strings as ISO 8601
class RawTemporalRange(BaseModel):
    valid_at: str | None = Field(None, description="The start date/time in ISO 8601 format.")
    invalid_at: str | None = Field(None, description="The end date/time in ISO 8601 format.")

# Model for validated temporal range with datetime objects
class TemporalValidityRange(BaseModel):
    valid_at: datetime | None = None
    invalid_at: datetime | None = None

    # Validator to parse date strings into datetime objects before assignment
    @field_validator("valid_at", "invalid_at", mode="before")
    @classmethod
    def _parse_date_string(cls, value: str | datetime | None) -> datetime | None:
        return parse_date_str(value)

In [None]:
# Prompt guiding the LLM to extract temporal validity ranges from statements
date_extraction_prompt_template = """
You are a temporal information extraction specialist.

INPUTS:
- statement: "{statement}"
- statement_type: "{statement_type}"
- temporal_type: "{temporal_type}"
- publication_date: "{publication_date}"
- quarter: "{quarter}"

TASK:
- Analyze the statement and determine the temporal validity range (valid_at, invalid_at).
- Use the publication date as the reference point for relative expressions (e.g., "currently").
- If a relationship is ongoing or its end is not specified, `invalid_at` should be null.

GUIDANCE:
- For STATIC statements, `valid_at` is the date the event occurred, and `invalid_at` is null.
- For DYNAMIC statements, `valid_at` is when the state began, and `invalid_at` is when it ended.
- Return dates in ISO 8601 format (e.g., YYYY-MM-DDTHH:MM:SSZ).

**Output format**
Return ONLY a valid JSON object matching the schema for `RawTemporalRange`.
"""

# Create a LangChain prompt template from the string
date_extraction_prompt = ChatPromptTemplate.from_template(date_extraction_prompt_template)

In [None]:
# Reuse the existing LLM instance.
# Create a chain by connecting the date extraction prompt
# with the LLM configured to output structured RawTemporalRange objects.
date_extraction_chain = date_extraction_prompt | llm.with_structured_output(RawTemporalRange)

In [None]:
# Take the first extracted statement for date extraction testing
sample_statement = extracted_statements_list.statements[0]
chunk_metadata = sample_chunk_for_extraction.metadata

print(f"--- Running date extraction for statement ---")
print(f'Statement: "{sample_statement.statement}"')
print(f"Reference Publication Date: {chunk_metadata['date'].isoformat()}")

# Invoke the date extraction chain with relevant inputs
raw_temporal_range = date_extraction_chain.invoke({
    "statement": sample_statement.statement,
    "statement_type": sample_statement.statement_type.value,
    "temporal_type": sample_statement.temporal_type.value,
    "publication_date": chunk_metadata["date"].isoformat(),
    "quarter": chunk_metadata["quarter"]
})

# Parse and validate raw LLM output into a structured TemporalValidityRange model
final_temporal_range = TemporalValidityRange.model_validate(raw_temporal_range.model_dump())

print("\n--- Parsed & Validated Result ---")
print(f"Valid At: {final_temporal_range.valid_at}")
print(f"Invalid At: {final_temporal_range.invalid_at}")

--- Running date extraction for statement ---
Statement: "AMD has been very focused on the server launch for the first half of 2017."
Reference Publication Date: 2016-07-21

--- Parsed & Validated Result ---
Valid At: 2017-01-01 00:00:00+00:00
Invalid At: 2017-06-30 23:59:59+00:00


## Structuring Facts into Triplets

At this stage we have the facts (the what) and their dates (the when), and we break each fact into a triplet: (1) the subject, i.e. the main entity (e.g. AMD); (2) the predicate, i.e. the relationship or action; and (3) the object, i.e. the related entity or concept. This format makes it straightforward to connect facts into a knowledge graph later on.
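Because each triplet is just (subject, predicate, object), triplets that share an entity name chain together naturally. A minimal stdlib sketch, using hypothetical triplets and a plain dictionary as a stand-in for a real graph library such as `networkx`, shows how a multi-hop question ("what does AMD's Zen target?") becomes a graph traversal. The `reachable` helper is illustrative only.

```python
from collections import defaultdict

# Hypothetical triplets (subject, predicate, object), not real extraction output
triplets = [
    ("AMD", "DEVELOPED", "Zen"),
    ("Zen", "TARGETS", "enthusiast desktop segment"),
    ("AMD", "LAUNCHED", "server products"),
]

# Adjacency list: subject -> list of (predicate, object) edges
graph: dict[str, list[tuple[str, str]]] = defaultdict(list)
for subject, predicate, obj in triplets:
    graph[subject].append((predicate, obj))

def reachable(start: str) -> set[str]:
    """All nodes reachable from `start` by following edges (depth-first)."""
    seen, stack = set(), [start]
    while stack:
        node = stack.pop()
        for _, neighbour in graph.get(node, []):
            if neighbour not in seen:
                seen.add(neighbour)
                stack.append(neighbour)
    return seen

print(graph["AMD"])  # [('DEVELOPED', 'Zen'), ('LAUNCHED', 'server products')]
print(sorted(reachable("AMD")))
# ['Zen', 'enthusiast desktop segment', 'server products']
```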

In [None]:
from enum import Enum  # Import the Enum base class to create enumerated constants

# Enum representing a fixed set of relationship predicates for graph consistency
class Predicate(str, Enum):
    # Each member of this Enum represents a specific type of relationship between entities
    IS_A = "IS_A"                # Represents an "is a" relationship (e.g., a Dog IS_A Animal)
    HAS_A = "HAS_A"              # Represents possession or composition (e.g., a Car HAS_A Engine)
    LOCATED_IN = "LOCATED_IN"    # Represents location relationship (e.g., Store LOCATED_IN City)
    HOLDS_ROLE = "HOLDS_ROLE"    # Represents role or position held (e.g., Person HOLDS_ROLE Manager)
    PRODUCES = "PRODUCES"        # Represents production or creation relationship
    SELLS = "SELLS"              # Represents selling relationship between entities
    LAUNCHED = "LAUNCHED"        # Represents launch events (e.g., Company LAUNCHED Product)
    DEVELOPED = "DEVELOPED"      # Represents development relationship (e.g., Team DEVELOPED Software)
    ADOPTED_BY = "ADOPTED_BY"    # Represents adoption relationship (e.g., Policy ADOPTED_BY Organization)
    INVESTS_IN = "INVESTS_IN"    # Represents investment relationships (e.g., Company INVESTS_IN Startup)
    COLLABORATES_WITH = "COLLABORATES_WITH"  # Represents collaboration between entities
    SUPPLIES = "SUPPLIES"        # Represents supplier relationship (e.g., Supplier SUPPLIES Parts)
    HAS_REVENUE = "HAS_REVENUE"  # Represents revenue relationship for entities
    INCREASED = "INCREASED"      # Represents an increase event or metric change
    DECREASED = "DECREASED"      # Represents a decrease event or metric change
    RESULTED_IN = "RESULTED_IN"  # Represents causal relationship (e.g., Event RESULTED_IN Outcome)
    TARGETS = "TARGETS"          # Represents target or goal relationship
    PART_OF = "PART_OF"          # Represents part-whole relationship (e.g., Wheel PART_OF Car)
    DISCONTINUED = "DISCONTINUED" # Represents discontinued status or event
    SECURED = "SECURED"          # Represents secured or obtained relationship (e.g., Company SECURED Funding)

In [None]:
from pydantic import BaseModel, Field
from typing import List, Optional

# Model representing an entity extracted by the LLM
class RawEntity(BaseModel):
    entity_idx: int = Field(description="A temporary, 0-indexed ID for this entity.")
    name: str = Field(description="The name of the entity, e.g., 'AMD' or 'Lisa Su'.")
    type: str = Field("Unknown", description="The type of entity, e.g., 'Organization', 'Person'.")
    description: str = Field("", description="A brief description of the entity.")

# Model representing a single subject-predicate-object triplet
class RawTriplet(BaseModel):
    subject_name: str
    subject_id: int = Field(description="The entity_idx of the subject.")
    predicate: Predicate
    object_name: str
    object_id: int = Field(description="The entity_idx of the object.")
    value: Optional[str] = Field(None, description="An optional value, e.g., '10%'.")

# Container for all entities and triplets extracted from a single statement
class RawExtraction(BaseModel):
    entities: List[RawEntity]
    triplets: List[RawTriplet]

In [None]:
# These definitions guide the LLM in choosing the correct predicate.
PREDICATE_DEFINITIONS = {
    "IS_A": "Denotes a class-or-type relationship (e.g., 'Model Y IS_A electric-SUV').",
    "HAS_A": "Denotes a part-whole relationship (e.g., 'Model Y HAS_A electric-engine').",
    "LOCATED_IN": "Specifies geographic or organisational containment.",
    "HOLDS_ROLE": "Connects a person to a formal office or title.",
}

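# Format the predicate instructions for the prompt. List every Predicate member so the
# model sees the full vocabulary; members without a definition above appear by name only.
predicate_instructions_text = "\n".join(
    f"- {pred.value}: {PREDICATE_DEFINITIONS[pred.value]}"
    if pred.value in PREDICATE_DEFINITIONS
    else f"- {pred.value}"
    for pred in Predicate
)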

In [None]:
# Prompt for extracting entities and subject-predicate-object triplets from a statement
triplet_extraction_prompt_template = """
You are an information-extraction assistant.

Task: From the statement, identify all entities (people, organizations, products, concepts) and all triplets (subject, predicate, object) describing their relationships.

Statement: "{statement}"

Predicate list:
{predicate_instructions}

Guidelines:
- List entities with unique `entity_idx`.
- List triplets linking subjects and objects by `entity_idx`.
- Exclude temporal expressions from entities and triplets.

Example:
Statement: "Google's revenue increased by 10% from January through March."
Output: {{
  "entities": [
    {{"entity_idx": 0, "name": "Google", "type": "Organization", "description": "A multinational technology company."}},
    {{"entity_idx": 1, "name": "Revenue", "type": "Financial Metric", "description": "Income from normal business."}}
  ],
  "triplets": [
    {{"subject_name": "Google", "subject_id": 0, "predicate": "INCREASED", "object_name": "Revenue", "object_id": 1, "value": "10%"}}
  ]
}}

Return ONLY a valid JSON object matching `RawExtraction`.
"""

# Initializing the prompt template
triplet_extraction_prompt = ChatPromptTemplate.from_template(triplet_extraction_prompt_template)

In [None]:
# Create the chain for triplet and entity extraction.
triplet_extraction_chain = triplet_extraction_prompt | llm.with_structured_output(RawExtraction)

# Let's use the same statement we've been working with.
sample_statement_for_triplets = extracted_statements_list.statements[0]

print(f"--- Running triplet extraction for statement ---")
print(f'Statement: "{sample_statement_for_triplets.statement}"')

# Invoke the chain.
raw_extraction_result = triplet_extraction_chain.invoke({
    "statement": sample_statement_for_triplets.statement,
    "predicate_instructions": predicate_instructions_text
})

print("\n--- Triplet Extraction Result ---")
print(raw_extraction_result.model_dump_json(indent=2))

--- Running triplet extraction for statement ---
Statement: "AMD has been very focused on the server launch for the first half of 2017."

--- Triplet Extraction Result ---
{
  "entities": [
    {
      "entity_idx": 0,
      "name": "AMD",
      "type": "Organization",
      "description": "A multinational semiconductor company."
    },
    {
      "entity_idx": 1,
      "name": "server launch",
      "type": "Event",
      "description": "The release of server-related products."
    }
  ],
  "triplets": [
    {
      "subject_name": "AMD",
      "subject_id": 0,
      "predicate": "HAS_A",
      "object_name": "server launch",
      "object_id": 1,
      "value": null
    }
  ]
}


## Assembling the Temporal Event

This stage merges the statement, its dates, and its entities and triplets into a single object called a `TemporalEvent`. It acts as the master record, holding the original statement, its labels (e.g. FACT, DYNAMIC), its validity dates, the IDs of its associated triplets, and tracking metadata. Each `TemporalEvent` is assigned a UUID (universally unique identifier) as its primary key so that every fact can be tracked across the knowledge graph.

In [None]:
import uuid
from pydantic import BaseModel, Field

# Final persistent model for an entity in your knowledge graph
class Entity(BaseModel):
    id: uuid.UUID = Field(default_factory=uuid.uuid4, description="Unique UUID for the entity")
    name: str = Field(..., description="The name of the entity")
    type: str = Field(..., description="Entity type, e.g., 'Organization', 'Person'")
    description: str = Field("", description="Brief description of the entity")
    resolved_id: uuid.UUID | None = Field(None, description="UUID of resolved entity if merged")

# Final persistent model for a triplet relationship
class Triplet(BaseModel):
    id: uuid.UUID = Field(default_factory=uuid.uuid4, description="Unique UUID for the triplet")
    subject_name: str = Field(..., description="Name of the subject entity")
    subject_id: uuid.UUID = Field(..., description="UUID of the subject entity")
    predicate: Predicate = Field(..., description="Relationship predicate")
    object_name: str = Field(..., description="Name of the object entity")
    object_id: uuid.UUID = Field(..., description="UUID of the object entity")
    value: str | None = Field(None, description="Optional value associated with the triplet")

In [None]:
class TemporalEvent(BaseModel):
    """
    The central model that consolidates all extracted information.
    """
    id: uuid.UUID = Field(default_factory=uuid.uuid4)
    chunk_id: uuid.UUID # To link back to the original text chunk
    statement: str
    embedding: list[float] = [] # For similarity checks later

    # Information from our previous extraction steps
    statement_type: StatementType
    temporal_type: TemporalType
    valid_at: datetime | None = None
    invalid_at: datetime | None = None

    # A list of the IDs of the triplets associated with this event
    triplets: list[uuid.UUID]

    # Extra metadata for tracking changes over time
    # Timezone-aware (UTC) so it can be compared with valid_at / invalid_at
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    expired_at: datetime | None = None
    invalidated_by: uuid.UUID | None = None
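The last three fields are what make events dynamic: they stay empty until a later statement supersedes this one. The real decision is made by the invalidation agent later in the pipeline; as a rough, hypothetical sketch of just the bookkeeping, using a cut-down dataclass `EventStub` and a helper `invalidate` that are not part of the notebook, the rule "a newer statement closes an older, still-open one" could look like the following. It deliberately skips the hard part, deciding whether the two statements actually conflict.

```python
import uuid
from dataclasses import dataclass, field, replace
from datetime import datetime, timezone
from typing import Optional

@dataclass(frozen=True)
class EventStub:
    """Hypothetical, cut-down stand-in for TemporalEvent (timing fields only)."""
    statement: str
    valid_at: Optional[datetime]
    invalid_at: Optional[datetime] = None
    expired_at: Optional[datetime] = None
    invalidated_by: Optional[uuid.UUID] = None
    id: uuid.UUID = field(default_factory=uuid.uuid4)

def invalidate(old: EventStub, new: EventStub, now: datetime) -> EventStub:
    """Return a copy of `old` closed off by `new`, or `old` unchanged if no update applies."""
    if new.valid_at is None:
        return old  # nothing to anchor the change to
    if old.valid_at is not None and new.valid_at <= old.valid_at:
        return old  # the "new" statement is not actually newer
    if old.invalid_at is not None and old.invalid_at <= new.valid_at:
        return old  # the old statement had already ended
    return replace(old, invalid_at=new.valid_at, expired_at=now, invalidated_by=new.id)

utc = timezone.utc
old = EventStub("Hypothetical guidance: volume ships in Q1.", datetime(2016, 7, 21, tzinfo=utc))
new = EventStub("Hypothetical guidance: volume now ships in Q2.", datetime(2016, 10, 20, tzinfo=utc))
now = datetime(2024, 1, 1, tzinfo=utc)

updated = invalidate(old, new, now)
print(updated.invalid_at.date(), updated.invalidated_by == new.id)  # 2016-10-20 True
print(invalidate(new, old, now) is new)  # True: an older statement never expires a newer one
```

Returning a modified copy rather than mutating in place keeps the original record intact, which makes it easy to audit when and why a fact was expired.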

In [None]:
# Assume these are already defined from previous steps:
# sample_statement, final_temporal_range, raw_extraction_result

print("--- Assembling the final TemporalEvent ---")

# 1. Convert raw entities to persistent Entity objects with UUIDs
idx_to_entity_map: dict[int, Entity] = {}
final_entities: list[Entity] = []

for raw_entity in raw_extraction_result.entities:
    entity = Entity(
        name=raw_entity.name,
        type=raw_entity.type,
        description=raw_entity.description
    )
    idx_to_entity_map[raw_entity.entity_idx] = entity
    final_entities.append(entity)

print(f"Created {len(final_entities)} persistent Entity objects.")

# 2. Convert raw triplets to persistent Triplet objects, linking entities via UUIDs
final_triplets: list[Triplet] = []

for raw_triplet in raw_extraction_result.triplets:
    subject_entity = idx_to_entity_map[raw_triplet.subject_id]
    object_entity = idx_to_entity_map[raw_triplet.object_id]

    triplet = Triplet(
        subject_name=subject_entity.name,
        subject_id=subject_entity.id,
        predicate=raw_triplet.predicate,
        object_name=object_entity.name,
        object_id=object_entity.id,
        value=raw_triplet.value
    )
    final_triplets.append(triplet)

print(f"Created {len(final_triplets)} persistent Triplet objects.")

--- Assembling the final TemporalEvent ---
Created 2 persistent Entity objects.
Created 1 persistent Triplet objects.


In [None]:
# 3. Create the final TemporalEvent object
# We'll generate a dummy chunk_id for this example.
temporal_event = TemporalEvent(
    chunk_id=uuid.uuid4(), # Placeholder ID
    statement=sample_statement.statement,
    statement_type=sample_statement.statement_type,
    temporal_type=sample_statement.temporal_type,
    valid_at=final_temporal_range.valid_at,
    invalid_at=final_temporal_range.invalid_at,
    triplets=[t.id for t in final_triplets]
)

print("\n--- Final Assembled TemporalEvent ---")
print(temporal_event.model_dump_json(indent=2))

print("\n--- Associated Entities ---")
for entity in final_entities:
    print(entity.model_dump_json(indent=2))

print("\n--- Associated Triplets ---")
for triplet in final_triplets:
    print(triplet.model_dump_json(indent=2))


--- Final Assembled TemporalEvent ---
{
  "id": "d7a2bc93-f1c2-4356-8281-4c64948b1a6b",
  "chunk_id": "f4669226-e138-4bd4-a4e2-e6a8393fdaea",
  "statement": "AMD has been very focused on the server launch for the first half of 2017.",
  "embedding": [],
  "statement_type": "FACT",
  "temporal_type": "DYNAMIC",
  "valid_at": "2017-01-01T00:00:00Z",
  "invalid_at": "2017-06-30T23:59:59Z",
  "triplets": [
    "6d9294df-9693-4c47-875f-49d08083289c"
  ],
  "created_at": "2025-08-16T21:07:15.176944",
  "expired_at": null,
  "invalidated_by": null
}

--- Associated Entities ---
{
  "id": "ba639b08-b69e-4c9b-9189-643a5decac14",
  "name": "AMD",
  "type": "Organization",
  "description": "A multinational semiconductor company.",
  "resolved_id": null
}
{
  "id": "b93aa6ba-3868-4bf6-857b-658049af64e2",
  "name": "server launch",
  "type": "Event",
  "description": "The release of server-related products.",
  "resolved_id": null
}

--- Associated Triplets ---
{
  "id": "6d9294df-9693-4c47-875f-4

## Automating the Pipeline with LangGraph

Up to now, statements, dates and triplets were extracted step by step and assembled into a TemporalEvent by hand. To scale this across thousands of chunks, the process is automated with LangGraph, a library for building AI workflows as directed graphs. Each processing stage becomes a node in the graph. Here, statement, date and triplet extraction are batched inside a single `extract_events` node, and entity resolution and invalidation are added as further nodes later on. The workflow is defined with an explicit entry point and end point, so one call runs the whole pipeline over every document chunk. In one run, 19 chunks produced 95 structured temporal events, 213 entities and 121 triplets.
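To make the node-and-edge idea concrete, here is a plain-Python analogy. It is not LangGraph's API: each "node" is just a function that receives the shared state dict and returns an updated copy, and a linear graph runs the nodes in order from the entry point until it reaches the end. The node names, the `ToyState` fields and the naive sentence and triplet splitting are all hypothetical stand-ins for the LLM chains used in this notebook.

```python
from typing import Callable, TypedDict


class ToyState(TypedDict):
    chunks: list[str]
    statements: list[str]
    triplets: list[tuple[str, str, str]]


Node = Callable[[ToyState], ToyState]


def extract_statements(state: ToyState) -> ToyState:
    # Hypothetical node: treat each sentence of each chunk as one statement
    stmts = [s.strip() for c in state["chunks"] for s in c.split(".") if s.strip()]
    return {**state, "statements": stmts}


def build_triplets(state: ToyState) -> ToyState:
    # Hypothetical node: a naive "subject verb rest" split stands in for the LLM extractor
    trips = []
    for s in state["statements"]:
        parts = s.split(" ", 2)
        if len(parts) == 3:
            trips.append((parts[0], parts[1], parts[2]))
    return {**state, "triplets": trips}


def run_linear_graph(nodes: list[Node], state: ToyState) -> ToyState:
    # nodes[0] plays the role of the entry point; running off the end of the list plays END
    for node in nodes:
        state = node(state)
    return state


final = run_linear_graph(
    [extract_statements, build_triplets],
    {"chunks": ["ExampleCo launched Widget. ExampleCo expects growth"], "statements": [], "triplets": []},
)
print(final["triplets"])
```

Because every node reads from and writes to the same state, adding a stage (such as entity resolution) only means inserting another function into the sequence, which mirrors how the graph below is extended with `add_node` and `add_edge`.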

In [None]:
from typing import List, TypedDict
from langchain_core.documents import Document

class GraphState(TypedDict):
    """
    TypedDict representing the overall state of the knowledge graph ingestion.

    Attributes:
        chunks: List of Document chunks being processed.
        temporal_events: List of TemporalEvent objects extracted from chunks.
        entities: List of Entity objects in the graph.
        triplets: List of Triplet objects representing relationships.
    """
    chunks: List[Document]
    temporal_events: List[TemporalEvent]
    entities: List[Entity]
    triplets: List[Triplet]

In [None]:
def extract_events_from_chunks(state: GraphState) -> GraphState:
    chunks = state["chunks"]

    # Extract raw statements from each chunk
    raw_stmts = statement_extraction_chain.batch([{
        "main_entity": c.metadata["company"],
        "publication_date": c.metadata["date"].isoformat(),
        "document_chunk": c.page_content,
        "definitions": definitions_text
    } for c in chunks])

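    # Generate one UUID per chunk so every statement from the same chunk shares its chunk_id
    chunk_ids = [uuid.uuid4() for _ in chunks]

    # Flatten statements and attach each chunk's metadata and ID
    stmts = [{"raw": s, "meta": chunks[i].metadata, "cid": chunk_ids[i]}
             for i, rs in enumerate(raw_stmts) for s in rs.statements]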

    # Prepare inputs and batch extract temporal data
    dates = date_extraction_chain.batch([{
        "statement": s["raw"].statement,
        "statement_type": s["raw"].statement_type.value,
        "temporal_type": s["raw"].temporal_type.value,
        "publication_date": s["meta"]["date"].isoformat(),
        "quarter": s["meta"]["quarter"]
    } for s in stmts])

    # Prepare inputs and batch extract triplets
    trips = triplet_extraction_chain.batch([{
        "statement": s["raw"].statement,
        "predicate_instructions": predicate_instructions_text
    } for s in stmts])

    events, entities, triplets = [], [], []

    for i, s in enumerate(stmts):
        # Validate temporal range data
        tr = TemporalValidityRange.model_validate(dates[i].model_dump())
        ext = trips[i]

        # Map entities by index and collect them (Pydantic models only accept keyword arguments)
        idx_map = {e.entity_idx: Entity(name=e.name, type=e.type, description=e.description)
                   for e in ext.entities}
        entities.extend(idx_map.values())

        # Build triplets only if subject and object entities exist
        tpls = [Triplet(
            subject_name=idx_map[t.subject_id].name, subject_id=idx_map[t.subject_id].id,
            predicate=t.predicate,
            object_name=idx_map[t.object_id].name, object_id=idx_map[t.object_id].id,
            value=t.value)
            for t in ext.triplets if t.subject_id in idx_map and t.object_id in idx_map]
        triplets.extend(tpls)

        # Create TemporalEvent with linked triplet IDs
        events.append(TemporalEvent(
            chunk_id=s["cid"], statement=s["raw"].statement,
            statement_type=s["raw"].statement_type, temporal_type=s["raw"].temporal_type,
            valid_at=tr.valid_at, invalid_at=tr.invalid_at,
            triplets=[t.id for t in tpls]
        ))

    return {"chunks": chunks, "temporal_events": events, "entities": entities, "triplets": triplets}

In [None]:
!pip install -U langgraph

Collecting langgraph
  Downloading langgraph-0.6.5-py3-none-any.whl.metadata (6.8 kB)
Collecting langgraph-checkpoint<3.0.0,>=2.1.0 (from langgraph)
  Downloading langgraph_checkpoint-2.1.1-py3-none-any.whl.metadata (4.2 kB)
Collecting langgraph-prebuilt<0.7.0,>=0.6.0 (from langgraph)
  Downloading langgraph_prebuilt-0.6.4-py3-none-any.whl.metadata (4.5 kB)
Collecting langgraph-sdk<0.3.0,>=0.2.0 (from langgraph)
  Downloading langgraph_sdk-0.2.0-py3-none-any.whl.metadata (1.5 kB)
Collecting ormsgpack>=1.10.0 (from langgraph-checkpoint<3.0.0,>=2.1.0->langgraph)
  Downloading ormsgpack-1.10.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.7/43.7 kB[0m [31m865.9 kB/s[0m eta [36m0:00:00[0m
Downloading langgraph-0.6.5-py3-none-any.whl (153 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m153.2/153.2 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading langgraph_che

In [None]:
from langgraph.graph import StateGraph, END

# Define a new graph using our state
workflow = StateGraph(GraphState)

# Add our function as a node named "extract_events"
workflow.add_node("extract_events", extract_events_from_chunks)

# Define the starting point of the graph
workflow.set_entry_point("extract_events")

# Define the end point of the graph
workflow.add_edge("extract_events", END)

# Compile the graph into a runnable application
app = workflow.compile()

In [None]:
# The input is a dictionary matching our GraphState, providing the initial chunks
graph_input = {"chunks": chunked_documents_lc}

# Invoke the graph. This will run our entire extraction pipeline in one call.
final_state = app.invoke(graph_input)


APIStatusError: Error code: 402 - {'detail': 'Payment Required: You have exhausted your budget. Please add funds to continue using the API.'}

In [None]:
# Check the number of objects created in the final state
num_events = len(final_state['temporal_events'])
num_entities = len(final_state['entities'])
num_triplets = len(final_state['triplets'])

print(f"Total TemporalEvents created: {num_events}")
print(f"Total Entities created: {num_entities}")
print(f"Total Triplets created: {num_triplets}")

print("\n--- Sample TemporalEvent from the final state ---")
# Print a sample event to see the fully assembled object
print(final_state['temporal_events'][5].model_dump_json(indent=2))

## Cleaning Our Data with Entity Resolution

The next stage handles the same entity appearing under several names (for example "AMD" and "AMD Inc."). Entity resolution clusters these mentions and assigns each cluster a single canonical ID. It runs as a new LangGraph node that (1) groups entities by type (e.g. Person, Organization), (2) uses fuzzy string matching to detect duplicates, (3) picks a canonical version for each cluster and (4) updates all triplets to use the canonical ID instead of the duplicates. Canonical entities are stored in a SQLite database so that later runs resolve new mentions against them.
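The four steps can be sketched without any third-party packages. This is a minimal illustration, not the notebook's node: it swaps rapidfuzz for the standard library's `difflib.SequenceMatcher`, approximates `partial_ratio` with a simple "one name contains the other" rule, and uses a hypothetical `Ent` dataclass and `WORKS_AT` predicate in place of the Pydantic models.

```python
import uuid
from collections import defaultdict
from dataclasses import dataclass, field
from difflib import SequenceMatcher


@dataclass
class Ent:
    name: str
    type: str
    id: uuid.UUID = field(default_factory=uuid.uuid4)


def ratio(a: str, b: str) -> float:
    # difflib stand-in for rapidfuzz's fuzz.ratio (0-100, case-insensitive)
    return 100.0 * SequenceMatcher(None, a.lower(), b.lower()).ratio()


def partial_match(a: str, b: str) -> float:
    # Rough stand-in for fuzz.partial_ratio: full score if one name contains the other
    a_l, b_l = a.lower(), b.lower()
    return 100.0 if a_l in b_l or b_l in a_l else ratio(a, b)


def resolve(entities: list[Ent], threshold: float = 80.0) -> dict[uuid.UUID, uuid.UUID]:
    """Map every entity ID to the ID of its cluster's canonical (medoid) entity."""
    by_type: dict[str, list[Ent]] = defaultdict(list)
    for e in entities:                                   # (1) group by type
        by_type[e.type].append(e)

    canonical: dict[uuid.UUID, uuid.UUID] = {}
    for group in by_type.values():
        used: set[int] = set()
        for i, seed in enumerate(group):                 # (2) greedy clustering around a seed
            if i in used:
                continue
            used.add(i)
            cluster = [seed]
            for j in range(i + 1, len(group)):
                if j not in used and partial_match(seed.name, group[j].name) >= threshold:
                    cluster.append(group[j])
                    used.add(j)
            # (3) the medoid is the member with the highest total similarity to its cluster
            medoid = max(cluster, key=lambda e: sum(ratio(e.name, o.name) for o in cluster))
            for e in cluster:
                canonical[e.id] = medoid.id
    return canonical


amd, amd_inc = Ent("AMD", "Organization"), Ent("AMD Inc.", "Organization")
lisa = Ent("Lisa Su", "Person")
mapping = resolve([amd, amd_inc, lisa])

# (4) rewrite a (subject_id, predicate, object_id) triplet onto canonical IDs
triplet = (lisa.id, "WORKS_AT", amd_inc.id)
resolved_triplet = (mapping[triplet[0]], triplet[1], mapping[triplet[2]])
```

Here "AMD" and "AMD Inc." tie on total similarity, so `max` keeps the first one seen ("AMD") as canonical, and the triplet ends up pointing at its ID. Grouping by type first keeps a person and an organization with similar names from ever being merged.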

In [None]:
import sqlite3

def setup_in_memory_db():
    """
    Sets up an in-memory SQLite database and creates the 'entities' table.

    The 'entities' table schema:
    - id: TEXT, Primary Key
    - name: TEXT, name of the entity
    - type: TEXT, type/category of the entity
    - description: TEXT, description of the entity
    - is_canonical: INTEGER, flag to indicate if entity is canonical (default 1)

    Returns:
        sqlite3.Connection: A connection object to the in-memory database.
    """
    # Establish connection to an in-memory SQLite database
    conn = sqlite3.connect(":memory:")

    # Create a cursor object to execute SQL commands
    cursor = conn.cursor()

    # Create the 'entities' table if it doesn't already exist
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS entities (
            id TEXT PRIMARY KEY,
            name TEXT,
            type TEXT,
            description TEXT,
            is_canonical INTEGER DEFAULT 1
        )
    """)

    # Commit changes to save the table schema
    conn.commit()

    # Return the connection object for further use
    return conn

# Create the database connection and set up the entities table
db_conn = setup_in_memory_db()

In [None]:
from rapidfuzz import fuzz
from collections import defaultdict

def resolve_entities_in_state(state: GraphState) -> GraphState:
    """
    A LangGraph node to perform entity resolution on the extracted entities.
    """
    print("\n--- Entering Node: resolve_entities_in_state ---")
    entities = state["entities"]
    triplets = state["triplets"]

    cursor = db_conn.cursor()
    cursor.execute("SELECT id, name FROM entities WHERE is_canonical = 1")
    global_canonicals = {row[1]: uuid.UUID(row[0]) for row in cursor.fetchall()}

    print(f"Starting resolution with {len(entities)} entities. Found {len(global_canonicals)} canonicals in DB.")

    # Group entities by type (e.g., 'Person', 'Organization') for more accurate matching
    type_groups = defaultdict(list)
    for entity in entities:
        type_groups[entity.type].append(entity)

    resolved_id_map = {} # Maps an old entity ID to its new canonical ID
    newly_created_canonicals = {}

    for entity_type, group in type_groups.items():
        if not group: continue

        # Cluster entities in the group by fuzzy name matching
        clusters = []
        used_indices = set()
        for i in range(len(group)):
            if i in used_indices: continue
            current_cluster = [group[i]]
            used_indices.add(i)
            for j in range(i + 1, len(group)):
                if j in used_indices: continue
                # partial_ratio scores the best-matching substring, so "AMD" vs "AMD Inc." scores 100;
                # it will not link an acronym to its expansion (e.g. "Advanced Micro Devices")
                score = fuzz.partial_ratio(group[i].name.lower(), group[j].name.lower())
                if score >= 80.0: # A similarity threshold of 80%
                    current_cluster.append(group[j])
                    used_indices.add(j)
            clusters.append(current_cluster)

        # For each cluster, find the best canonical representation (the "medoid")
        for cluster in clusters:
            scores = {e.name: sum(fuzz.ratio(e.name.lower(), other.name.lower()) for other in cluster) for e in cluster}
            medoid_entity = max(cluster, key=lambda e: scores[e.name])
            canonical_name = medoid_entity.name

            # Check if this canonical name already exists or was just created in this run
            if canonical_name in global_canonicals:
                canonical_id = global_canonicals[canonical_name]
            elif canonical_name in newly_created_canonicals:
                canonical_id = newly_created_canonicals[canonical_name].id
            else:
                # Create a new canonical entity
                canonical_id = medoid_entity.id
                newly_created_canonicals[canonical_name] = medoid_entity

            # Map all entities in this cluster to the single canonical ID
            for entity in cluster:
                entity.resolved_id = canonical_id
                resolved_id_map[entity.id] = canonical_id

    # Update the triplets in our state to use the new canonical IDs
    for triplet in triplets:
        if triplet.subject_id in resolved_id_map:
            triplet.subject_id = resolved_id_map[triplet.subject_id]
        if triplet.object_id in resolved_id_map:
            triplet.object_id = resolved_id_map[triplet.object_id]

    # Add any newly created canonical entities to our database
    if newly_created_canonicals:
        print(f"Adding {len(newly_created_canonicals)} new canonical entities to the DB.")
        new_data = [(str(e.id), e.name, e.type, e.description, 1) for e in newly_created_canonicals.values()]
        cursor.executemany("INSERT INTO entities (id, name, type, description, is_canonical) VALUES (?, ?, ?, ?, ?)", new_data)
        db_conn.commit()

    print("Entity resolution complete.")
    return state

In [None]:
# Re-define the graph to include the new node
workflow = StateGraph(GraphState)

# Add our two nodes to the graph
workflow.add_node("extract_events", extract_events_from_chunks)
workflow.add_node("resolve_entities", resolve_entities_in_state)

# Define the new sequence of steps
workflow.set_entry_point("extract_events")
workflow.add_edge("extract_events", "resolve_entities")
workflow.add_edge("resolve_entities", END)

# Compile the updated workflow
app_with_resolution = workflow.compile()

In [None]:
# Use the same input as before
graph_input = {"chunks": chunked_documents_lc}

# Invoke the new graph
final_state_with_resolution = app_with_resolution.invoke(graph_input)

In [None]:
# Find a sample entity that has been resolved (i.e., has a resolved_id)
sample_resolved_entity = next((e for e in final_state_with_resolution['entities'] if e.resolved_id is not None and e.id != e.resolved_id), None)

if sample_resolved_entity:
    print("\n--- Sample of a Resolved Entity ---")
    print(sample_resolved_entity.model_dump_json(indent=2))
else:
    print("\nNo sample resolved entity found (all entities were unique in this small run).")

# Check a triplet to see its updated canonical IDs
sample_resolved_triplet = final_state_with_resolution['triplets'][0]
print("\n--- Sample Triplet with Resolved IDs ---")
print(sample_resolved_triplet.model_dump_json(indent=2))

## Making Our Knowledge Dynamic with an Invalidation Agent

The biggest challenge is that facts change over time, and this is where the invalidation agent comes in. It acts like a referee: it spots contradictions between older dynamic facts and newer ones, and when something has changed, it sets the old fact's `invalid_at` timestamp. The process works as follows: (1) embed all statements to measure semantic similarity, (2) compare each dynamic fact against semantically similar facts that start at or after it, (3) use an LLM to judge whether the newer fact invalidates the older one, and (4) if so, record the invalidation time and the invalidating event on the old fact. This keeps the knowledge base dynamic and time-aware.
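The four steps can be sketched in plain Python with toy two-dimensional embeddings. This is a minimal illustration under stated assumptions: the `Event` dataclass is a simplified stand-in for `TemporalEvent`, the `judge` callable is a hypothetical placeholder for the LLM call, the embeddings are made up, and the notebook's extra requirement that candidates be `FACT` statements is left out.

```python
import math
from dataclasses import dataclass
from datetime import date
from typing import Callable, Optional


@dataclass
class Event:
    id: str
    statement: str
    dynamic: bool
    valid_at: date
    embedding: list[float]
    invalid_at: Optional[date] = None
    invalidated_by: Optional[str] = None


def cosine_sim(a: list[float], b: list[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def invalidate(events: list[Event], judge: Callable[[Event, Event], bool], min_sim: float = 0.5) -> None:
    for old in events:
        # Only dynamic events that are still open can be invalidated
        if not old.dynamic or old.invalid_at is not None:
            continue
        # (2) candidates: other events starting at or after `old` that are semantically close
        cands = [new for new in events
                 if new is not old and new.valid_at >= old.valid_at
                 and cosine_sim(old.embedding, new.embedding) > min_sim]
        # (3) keep the candidates the judge says contradict `old`
        hits = [c for c in cands if judge(old, c)]
        # (4) the earliest contradicting event closes the old event's validity window
        if hits:
            first = min(hits, key=lambda c: c.valid_at)
            old.invalid_at, old.invalidated_by = first.valid_at, first.id


events = [
    Event("e1", "ExampleCo expects revenue growth of about 10% this year", True, date(2017, 1, 31), [1.0, 0.1]),
    Event("e2", "ExampleCo now expects revenue growth of about 15% this year", True, date(2017, 7, 25), [0.9, 0.2]),
    Event("e3", "ExampleCo is headquartered in Springfield", False, date(2016, 1, 1), [0.0, 1.0]),
]
# Toy judge standing in for the LLM: a "now expects" restatement contradicts older guidance
invalidate(events, judge=lambda old, new: "now expects" in new.statement)
```

After the call, `e1` is closed on 2017-07-25 and points at `e2`, while `e2` stays open and the static `e3` is never touched. The cheap similarity filter runs first so that the expensive judge is only consulted for plausible pairs.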

In [None]:
# Obtain a cursor from the existing database connection
cursor = db_conn.cursor()

# Create the 'events' table to store event-related data
cursor.execute("""
CREATE TABLE IF NOT EXISTS events (
    id TEXT PRIMARY KEY,         -- Unique identifier for each event
    chunk_id TEXT,               -- Identifier for the chunk this event belongs to
    statement TEXT,              -- Textual representation of the event
    statement_type TEXT,         -- Statement category (e.g., FACT)
    temporal_type TEXT,          -- Temporal classification (e.g., STATIC, DYNAMIC)
    valid_at TEXT,               -- Timestamp when the event becomes valid
    invalid_at TEXT,             -- Timestamp when the event becomes invalid
    embedding BLOB               -- Optional embedding data stored as binary (e.g., vector)
)
""")

# Create the 'triplets' table to store relations between entities for events
cursor.execute("""
CREATE TABLE IF NOT EXISTS triplets (
    id TEXT PRIMARY KEY,         -- Unique identifier for each triplet
    event_id TEXT,               -- Foreign key referencing 'events.id'
    subject_id TEXT,             -- Subject entity ID in the triplet
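    predicate TEXT,              -- Predicate describing relation or action
    object_id TEXT,              -- Object entity ID in the triplet
    value TEXT                   -- Optional value attached to the triplet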
)
""")

# Commit all changes to the in-memory database
db_conn.commit()
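The `events` table declares `embedding` as a `BLOB`, which needs a byte encoding for a list of floats. One minimal option, sketched here with the standard library's `array` module on a separate throwaway connection, packs the vector as 8-byte doubles; the helper names `to_blob` and `from_blob` are hypothetical, and the byte order is the machine's native one, so blobs should be decoded on the same kind of platform that wrote them.

```python
import sqlite3
from array import array


def to_blob(vec: list[float]) -> bytes:
    # Pack floats as 8-byte doubles; decoding must use the same "d" typecode
    return array("d", vec).tobytes()


def from_blob(blob: bytes) -> list[float]:
    out = array("d")
    out.frombytes(blob)
    return out.tolist()


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE events (id TEXT PRIMARY KEY, statement TEXT, embedding BLOB)")
conn.execute(
    "INSERT INTO events VALUES (?, ?, ?)",
    ("e1", "ExampleCo targets about 10% growth", to_blob([0.25, -1.5, 3.0])),
)
(blob,) = conn.execute("SELECT embedding FROM events WHERE id = ?", ("e1",)).fetchone()
print(from_blob(blob))
```

Storing raw doubles keeps the round trip lossless; storing `array("f", ...)` instead would halve the size at the cost of float32 precision.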


In [None]:
# This prompt asks the LLM to act as a referee between two events.
event_invalidation_prompt_template = """
Task: Analyze the primary event against the secondary event and determine if the primary event is invalidated by the secondary event.
Return "True" if the primary event is invalidated, otherwise return "False".

Invalidation Guidelines:
1. An event can only be invalidated if it is DYNAMIC and its `invalid_at` is currently null.
2. A STATIC event (e.g., "X was hired on date Y") can invalidate a DYNAMIC event (e.g., "Z is the current employee").
3. Invalidation must be a direct contradiction. For example, "Lisa Su is CEO" is contradicted by "Someone else is CEO".
4. The invalidating event (secondary) must occur at or after the start of the primary event.

---
Primary Event (the one that might be invalidated):
- Statement: {primary_statement}
- Type: {primary_temporal_type}
- Valid From: {primary_valid_at}
- Valid To: {primary_invalid_at}

Secondary Event (the new fact that might cause invalidation):
- Statement: {secondary_statement}
- Type: {secondary_temporal_type}
- Valid From: {secondary_valid_at}
---

Is the primary event invalidated by the secondary event? Answer with only "True" or "False".
"""

invalidation_prompt = ChatPromptTemplate.from_template(event_invalidation_prompt_template)

# This chain returns an AIMessage; its .content should be the string "True" or "False".
invalidation_chain = invalidation_prompt | llm
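The node below treats a reply as a positive verdict only when its content is exactly `true` after stripping whitespace, so a reply such as `True.` would silently count as "not invalidated". A slightly more forgiving parser is sketched here; `parse_verdict` is a hypothetical helper, and treating any unrecognised reply as `False` is a deliberately conservative assumption (an ambiguous answer never invalidates a fact).

```python
def parse_verdict(text: str) -> bool:
    """Map an LLM reply such as 'True', ' true.' or '"False"' to a bool; anything else is False."""
    cleaned = text.strip().strip(".!\"'`").strip().lower()
    return cleaned == "true"


print(parse_verdict(" True.\n"), parse_verdict("False"), parse_verdict("Probably true"))
```

With such a helper, the check inside the loop would read `parse_verdict(r.content)` instead of comparing the raw string.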

In [None]:
from scipy.spatial.distance import cosine

def invalidate_events_in_state(state: GraphState) -> GraphState:
    """Mark dynamic events invalidated by later similar facts."""
    events = state["temporal_events"]

    # Embed all event statements
    embeds = embeddings.embed_documents([e.statement for e in events])
    for e, emb in zip(events, embeds):
        e.embedding = emb

    updates = {}
    for i, e1 in enumerate(events):
        # Skip non-dynamic or already invalidated events
        if e1.temporal_type != TemporalType.DYNAMIC or e1.invalid_at:
            continue

        # Find candidate events: facts starting at or after e1 with high similarity
        cands = [
            e2 for j, e2 in enumerate(events) if j != i and
            e2.statement_type == StatementType.FACT and e2.valid_at and e1.valid_at and
            e2.valid_at >= e1.valid_at and 1 - cosine(e1.embedding, e2.embedding) > 0.5
        ]
        if not cands:
            continue

        # Prepare inputs for LLM invalidation check
        inputs = [{
            "primary_statement": e1.statement, "primary_temporal_type": e1.temporal_type.value,
            "primary_valid_at": e1.valid_at.isoformat(), "primary_invalid_at": "None",
            "secondary_statement": c.statement, "secondary_temporal_type": c.temporal_type.value,
            "secondary_valid_at": c.valid_at.isoformat()
        } for c in cands]

        # Ask LLM which candidates invalidate the event
        results = invalidation_chain.batch(inputs)

        # Record earliest invalidation info
        for c, r in zip(cands, results):
            if r.content.strip().lower() == "true" and (e1.id not in updates or c.valid_at < updates[e1.id]["invalid_at"]):
                updates[e1.id] = {"invalid_at": c.valid_at, "invalidated_by": c.id}

    # Apply invalidations to events
    for e in events:
        if e.id in updates:
            e.invalid_at = updates[e.id]["invalid_at"]
            e.invalidated_by = updates[e.id]["invalidated_by"]

    return state

In [None]:
# Re-define the graph to include all three nodes
workflow = StateGraph(GraphState)

workflow.add_node("extract_events", extract_events_from_chunks)
workflow.add_node("resolve_entities", resolve_entities_in_state)
workflow.add_node("invalidate_events", invalidate_events_in_state)

# Define the complete pipeline flow
workflow.set_entry_point("extract_events")
workflow.add_edge("extract_events", "resolve_entities")
workflow.add_edge("resolve_entities", "invalidate_events")
workflow.add_edge("invalidate_events", END)

# Compile the final ingestion workflow
ingestion_app = workflow.compile()

In [None]:
# Use the same input as before
graph_input = {"chunks": chunked_documents_lc}

# Invoke the final graph
final_ingested_state = ingestion_app.invoke(graph_input)
print("\n--- Full graph execution with invalidation complete ---")

In [None]:
# Find and print an invalidated event from the final state
invalidated_event = next((e for e in final_ingested_state['temporal_events'] if e.invalidated_by is not None), None)

if invalidated_event:
    print("\n--- Sample of an Invalidated Event ---")
    print(invalidated_event.model_dump_json(indent=2))

    # Find the event that caused the invalidation
    invalidating_event = next((e for e in final_ingested_state['temporal_events'] if e.id == invalidated_event.invalidated_by), None)

    if invalidating_event:
        print("\n--- Was Invalidated By this Event ---")
        print(invalidating_event.model_dump_json(indent=2))
else:
    print("\nNo invalidated events were found in this run.")

## Assembling the Temporal Knowledge Graph

At this point, we've done the hard work of extracting facts, resolving duplicates and handling invalidations. Now it's time to assemble the temporal knowledge graph. Entities become the nodes, triplets (relationships) become the edges connecting them, and each edge carries temporal information: when it becomes valid, when it expires, and the statement behind it. We use NetworkX to build this structure directly from the final LangGraph pipeline output. In our run this produced a graph with 340 nodes and 434 edges, where you can zoom in on a company and see its relationships over time. This temporal knowledge graph becomes the 'brain' that our retrieval agent queries to answer nuanced, time-sensitive questions.
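What makes the graph temporal is that every edge carries its own validity window, so the same pair of nodes can hold several edges for different periods and a query can ask what was true on a given date. The sketch below shows that point-in-time ("as-of") filter using plain dataclasses rather than NetworkX; `ExampleCo`, the `TARGETS` predicate and the choice to treat `invalid_at` as an exclusive bound are hypothetical assumptions for illustration.

```python
from dataclasses import dataclass
from datetime import date
from typing import Optional


@dataclass(frozen=True)
class Edge:
    subject: str
    predicate: str
    obj: str
    statement: str
    valid_at: date
    invalid_at: Optional[date] = None   # None means "still valid as far as we know"


def edges_valid_on(edges: list[Edge], subject: str, on: date) -> list[Edge]:
    """Outgoing edges of `subject` whose validity window contains `on` (end bound exclusive)."""
    return [
        e for e in edges
        if e.subject == subject and e.valid_at <= on and (e.invalid_at is None or on < e.invalid_at)
    ]


edges = [
    Edge("ExampleCo", "TARGETS", "10% growth", "ExampleCo targets about 10% growth",
         date(2017, 1, 31), date(2017, 7, 25)),
    Edge("ExampleCo", "TARGETS", "15% growth", "ExampleCo now targets about 15% growth",
         date(2017, 7, 25)),
]
march = [e.obj for e in edges_valid_on(edges, "ExampleCo", date(2017, 3, 1))]
september = [e.obj for e in edges_valid_on(edges, "ExampleCo", date(2017, 9, 1))]
print(march, september)
```

Both edges share the same subject and predicate, which is why a multigraph that keeps parallel edges apart is needed: if they were stored under one key, the older window would be lost.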

In [None]:
import networkx as nx
import uuid

def build_graph_from_state(state: GraphState) -> nx.MultiDiGraph:
    """
    Builds a NetworkX graph from the final state of our ingestion pipeline.
    """
    print("--- Building Knowledge Graph from final state ---")

    entities = state["entities"]
    triplets = state["triplets"]
    temporal_events = state["temporal_events"]

    graph = nx.MultiDiGraph() # A directed graph that allows multiple edges

    # 1. Add a node for each unique, canonical entity.
    # A canonical ID may belong to an entity stored in the DB by an earlier run and therefore be
    # absent from this state; in that case a member of its cluster supplies the node attributes.
    canonical_repr = {}
    for e in entities:
        cid = e.resolved_id if e.resolved_id else e.id
        if cid not in canonical_repr or e.id == cid:
            canonical_repr[cid] = e
    for canonical_id, rep in canonical_repr.items():
        graph.add_node(
            str(canonical_id), # Node names in NetworkX are typically strings
            name=rep.name,
            type=rep.type,
            description=rep.description
        )

    print(f"Added {graph.number_of_nodes()} canonical entity nodes to the graph.")

    # 2. Add an edge for each triplet, decorated with temporal info
    edges_added = 0
    # Map each triplet ID to its parent event once, instead of scanning every event per triplet
    triplet_to_event = {tid: ev for ev in temporal_events for tid in ev.triplets}
    for triplet in triplets:
        # Find the parent event that this triplet belongs to
        parent_event = triplet_to_event.get(triplet.id)
        if not parent_event: continue

        # Get the canonical IDs for the subject and object
        subject_canonical_id = str(triplet.subject_id)
        object_canonical_id = str(triplet.object_id)

        # Add the edge to the graph
        if graph.has_node(subject_canonical_id) and graph.has_node(object_canonical_id):
            edge_attrs = {
                "predicate": triplet.predicate.value, "value": triplet.value,
                "statement": parent_event.statement, "valid_at": parent_event.valid_at,
                "invalid_at": parent_event.invalid_at,
                "statement_type": parent_event.statement_type.value
            }
            # Key by triplet ID: keying by predicate would let a later edge with the same subject,
            # object and predicate overwrite an earlier one and lose its time window
            graph.add_edge(
                subject_canonical_id, object_canonical_id,
                key=str(triplet.id), **edge_attrs
            )
            edges_added += 1

    print(f"Added {edges_added} edges (relationships) to the graph.")
    return graph

# Let's build the graph from the state we got from our LangGraph app
knowledge_graph = build_graph_from_state(final_ingested_state)

In [None]:
print(f"Graph has {knowledge_graph.number_of_nodes()} nodes and {knowledge_graph.number_of_edges()} edges.")

# Let's find the node for "AMD" by searching its 'name' attribute
amd_node_id = None
for node, data in knowledge_graph.nodes(data=True):
    if data.get('name', '').lower() == 'amd':
        amd_node_id = node
        break

if amd_node_id:
    print("\n--- Inspecting the 'AMD' node ---")
    print(f"Attributes: {knowledge_graph.nodes[amd_node_id]}")

    print("\n--- Sample Outgoing Edges from 'AMD' ---")
    for i, (u, v, data) in enumerate(knowledge_graph.out_edges(amd_node_id, data=True)):
        if i >= 3: break # Show the first 3 for brevity
        object_name = knowledge_graph.nodes[v]['name']
        valid_at = data.get('valid_at')
        valid_from = valid_at.date() if valid_at else "unknown"  # valid_at may be None
        print(f"Edge {i+1}: AMD --[{data['predicate']}]--> {object_name} (Valid From: {valid_from})")
else:
    print("Could not find a node for 'AMD'.")

In [None]:
import matplotlib.pyplot as plt

# Find the 15 most connected nodes to visualize
degrees = dict(knowledge_graph.degree())
top_nodes = sorted(degrees, key=degrees.get, reverse=True)[:15]

# Create a smaller graph containing only these top nodes
subgraph = knowledge_graph.subgraph(top_nodes)

# Draw the graph
plt.figure(figsize=(12, 12))
pos = nx.spring_layout(subgraph, k=0.8, iterations=50)
labels = {node: data['name'] for node, data in subgraph.nodes(data=True)}
nx.draw(subgraph, pos, labels=labels, with_labels=True, node_color='skyblue',
        node_size=2500, edge_color='#666666', font_size=10)
plt.title("Subgraph of Top 15 Most Connected Entities", size=16)
plt.show()

## Building and Testing A Multi-Step Retrieval Agent

The final stage moves from storing the temporal knowledge graph to using it. Single-step RAG retrieves once per question, but real questions often need several pieces of evidence stitched together. For example, asking how AMD's focus on data centres changed between 2016 and 2017 requires pulling facts from both years, comparing them and summarising the differences. This is where the multi-step retrieval agent comes in. It is built from three parts: a planner (which breaks the question into steps), tools (which fetch facts from the graph) and an orchestrator (which loops between thinking and acting until the answer is ready).
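The orchestrator's think, act, observe loop can be sketched independently of LangGraph. In this minimal illustration, a hypothetical scripted `model` function stands in for the tool-calling LLM, a hypothetical `lookup` function with canned answers stands in for `factual_qa`, and `ExampleCo` is a placeholder company; the planner's output would normally be part of the first message and is omitted here.

```python
from dataclasses import dataclass, field
from typing import Callable, Optional


@dataclass
class ModelTurn:
    # Either a tool request (tool_name + args) or a final answer
    tool_name: Optional[str] = None
    args: dict = field(default_factory=dict)
    answer: Optional[str] = None


def orchestrate(question: str,
                model: Callable[[list[str]], ModelTurn],
                tools: dict[str, Callable[..., str]],
                max_steps: int = 5) -> str:
    transcript = [f"QUESTION: {question}"]
    for _ in range(max_steps):
        turn = model(transcript)                          # think: decide the next step
        if turn.answer is not None:                       # finish once no tool is requested
            return turn.answer
        observation = tools[turn.tool_name](**turn.args)  # act: run the requested tool
        transcript.append(f"OBSERVATION: {observation}")  # observe: feed the result back
    return "Stopped: step budget exhausted."


def lookup(entity: str, year: int) -> str:
    # Hypothetical tool with canned facts, standing in for factual_qa
    facts = {2016: f"{entity} prioritised client PCs", 2017: f"{entity} prioritised data centre servers"}
    return facts.get(year, f"No facts for {entity} in {year}")


def scripted_model(transcript: list[str]) -> ModelTurn:
    # Hypothetical model: request one year at a time, then summarise once both are in
    obs = [t.removeprefix("OBSERVATION: ") for t in transcript if t.startswith("OBSERVATION: ")]
    if len(obs) < 2:
        return ModelTurn(tool_name="lookup", args={"entity": "ExampleCo", "year": 2016 + len(obs)})
    return ModelTurn(answer=f"2016: {obs[0]}; 2017: {obs[1]}")


answer = orchestrate("How did ExampleCo's focus change from 2016 to 2017?", scripted_model, {"lookup": lookup})
print(answer)
```

The `max_steps` budget matters in practice: a model that keeps requesting tools would otherwise loop forever, so the sketch stops and reports that the budget was exhausted.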

In [None]:
# System prompt describes the "persona" for the LLM
initial_planner_system_prompt = (
    "You are an expert financial research assistant. "
    "Your task is to create a step-by-step plan for answering a user's question "
    "by querying a temporal knowledge graph of earnings call transcripts. "
    "The available tool is `factual_qa`, which can retrieve facts about an entity "
    "for a specific topic (predicate) within a given date range. "
    "Your plan should consist of a series of calls to this tool."
)

# Template for the user prompt — receives `user_question` dynamically
initial_planner_user_prompt_template = """
User Question: "{user_question}"

Based on this question, create a concise, step-by-step plan.
Each step should be a clear action for querying the knowledge graph.

Return only the plan under a heading 'Research tasks'.
"""

# Create a ChatPromptTemplate that combines the system persona and the user prompt.
# `from_messages` takes a list of (role, content) pairs to form the conversation context.
planner_prompt = ChatPromptTemplate.from_messages([
    ("system", initial_planner_system_prompt),          # LLM's role and behavior
    ("user", initial_planner_user_prompt_template),     # Instructions for this specific run
])

# Create a "chain" that pipes the prompt into the LLM.
# The `|` operator here is the LangChain "Runnable" syntax for composing components.
planner_chain = planner_prompt | llm

In [None]:
# Our sample user question for the retrieval agent
user_question = "How did AMD's focus on data centers evolve between 2016 and 2017?"

print(f"--- Generating plan for question: '{user_question}' ---")
plan_result = planner_chain.invoke({"user_question": user_question})
initial_plan = plan_result.content

print("\n--- Generated Plan ---")
print(initial_plan)

In [None]:
from langchain_core.tools import tool
from datetime import date
import datetime as dt  # Alias the module to avoid confusing it with the `date` class

# Helper: normalise the dates the LLM passes to the tool. Accepts datetime objects,
# date objects, or ISO "YYYY-MM-DD" strings; anything else returns None.
def _as_datetime(ts) -> dt.datetime | None:
    if not ts:
        return None
    if isinstance(ts, dt.datetime):
        return ts
    if isinstance(ts, dt.date):
        return dt.datetime.combine(ts, dt.datetime.min.time())
    try:
        return dt.datetime.strptime(ts, "%Y-%m-%d")
    except (ValueError, TypeError):
        return None

@tool
def factual_qa(entity: str, start_date: date, end_date: date, predicate: str) -> str:
    """
    Queries the knowledge graph for facts about a specific entity, topic (predicate),
    and time range. Returns a formatted string of matching relationships.
    """
    print("\n--- TOOL CALL: factual_qa ---")
    print(f"  - Entity: {entity}, Predicate: {predicate}, Range: {start_date} to {end_date}")

    # Normalise both bounds to timezone-aware UTC datetimes so they compare with `valid_at`
    start_dt = _as_datetime(start_date)
    end_dt = _as_datetime(end_date)
    if start_dt is None or end_dt is None:
        return f"Error: could not parse date range '{start_date}' to '{end_date}'."
    start_dt = start_dt.replace(tzinfo=dt.timezone.utc)
    end_dt = end_dt.replace(tzinfo=dt.timezone.utc)

    # 1. Find the first node whose name contains the entity (case-insensitive substring match)
    target_node_id = next(
        (nid for nid, data in knowledge_graph.nodes(data=True)
         if entity.lower() in data.get('name', '').lower()),
        None,
    )
    if target_node_id is None:
        return f"Error: Entity '{entity}' not found."

    # 2. Scan the node's edges (outgoing edges only, if the graph is directed) for a
    #    matching predicate whose `valid_at` timestamp falls inside the date range
    matching_edges = []
    for u, v, data in knowledge_graph.edges(target_node_id, data=True):
        if predicate.upper() in data.get('predicate', '').upper():
            valid_at = data.get('valid_at')
            if valid_at and start_dt <= valid_at <= end_dt:
                subject = knowledge_graph.nodes[u]['name']
                obj = knowledge_graph.nodes[v]['name']
                matching_edges.append(f"Fact: {subject} --[{data['predicate']}]--> {obj}")

    if not matching_edges:
        return f"No facts found for '{entity}' with predicate '{predicate}' in that date range."
    return "\n".join(matching_edges)
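The core of `factual_qa` is a two-part filter: a case-insensitive substring match on the predicate, then a date-window check on `valid_at`. The sketch below isolates that logic on a plain list of dicts so it can be tried without a graph. It assumes, as the tool's `.replace(tzinfo=...)` calls suggest, that edges store timezone-aware UTC datetimes. `EDGES`, `ExampleCo`, `to_utc`, and `filter_facts` are hypothetical and the facts are invented purely for illustration.

```python
import datetime as dt

# Hypothetical toy edges standing in for the knowledge graph's edges (invented data)
EDGES = [
    {"subject": "ExampleCo", "predicate": "FOCUSES_ON", "object": "Data centers",
     "valid_at": dt.datetime(2016, 4, 21, tzinfo=dt.timezone.utc)},
    {"subject": "ExampleCo", "predicate": "FOCUSES_ON", "object": "Server CPUs",
     "valid_at": dt.datetime(2017, 10, 24, tzinfo=dt.timezone.utc)},
    {"subject": "ExampleCo", "predicate": "LAUNCHED", "object": "New GPU",
     "valid_at": dt.datetime(2016, 6, 29, tzinfo=dt.timezone.utc)},
]


def to_utc(day: dt.date) -> dt.datetime:
    """Turn a calendar date into midnight UTC so it compares with aware datetimes."""
    return dt.datetime.combine(day, dt.time.min, tzinfo=dt.timezone.utc)


def filter_facts(edges, predicate: str, start: dt.date, end: dt.date) -> list[str]:
    start_dt, end_dt = to_utc(start), to_utc(end)
    return [
        f"Fact: {e['subject']} --[{e['predicate']}]--> {e['object']}"
        for e in edges
        if predicate.upper() in e["predicate"].upper()  # case-insensitive substring
        and start_dt <= e["valid_at"] <= end_dt          # inclusive date window
    ]


print(filter_facts(EDGES, "focuses", dt.date(2016, 1, 1), dt.date(2016, 12, 31)))
# prints: ['Fact: ExampleCo --[FOCUSES_ON]--> Data centers']

# Ordering a naive datetime against an aware one raises TypeError, which is why
# both bounds get a tzinfo before filtering.
try:
    dt.datetime(2016, 1, 1) <= EDGES[0]["valid_at"]
except TypeError as err:
    print("TypeError:", err)
```

One consequence of turning dates into midnight: the end bound is the *start* of the end date, so a fact stamped later that same day (say 15:00 on 2016-12-31) falls outside the window. The same holds for `factual_qa`, since `_as_datetime` also maps plain dates to midnight.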

In [None]:
from typing import Annotated, TypedDict

from langchain_core.messages import BaseMessage, HumanMessage
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode

# Define the state for our retrieval agent's memory.
# The `add_messages` reducer appends each node's returned messages to the history;
# without it, every node's return value would overwrite the whole list.
class AgentState(TypedDict):
    messages: Annotated[list[BaseMessage], add_messages]

# Bind our factual_qa tool to the LLM. tool_choice="any" forces a tool call on every
# turn, which our specific model requires. Because the model is never allowed to reply
# without a tool call, the loop may only stop at LangGraph's recursion limit; use
# tool_choice="auto" if your model supports it so it can return a final answer.
tools = [factual_qa]
llm_with_tools = llm.bind_tools(tools, tool_choice="any")

# This is the "brain" of our agent: it sends the full message history to the LLM,
# which either requests a tool call or answers.
def call_model(state: AgentState):
    print("\n--- AGENT: Calling model to decide next step... ---")
    response = llm_with_tools.invoke(state['messages'])
    return {"messages": [response]}

# Conditional edge: route to the tool node if the LLM requested a tool call, otherwise finish.
def should_continue(state: AgentState) -> str:
    last_message = state['messages'][-1]
    if getattr(last_message, 'tool_calls', None):
        return "continue_with_tool"
    return "finish"

# Now, wire up the graph: agent -> (action -> agent)* -> END
workflow = StateGraph(AgentState)
workflow.add_node("agent", call_model)
workflow.add_node("action", ToolNode(tools))  # Pre-built node that executes the requested tool calls
workflow.set_entry_point("agent")
workflow.add_conditional_edges(
    "agent",
    should_continue,
    {"continue_with_tool": "action", "finish": END},
)
workflow.add_edge("action", "agent")

retrieval_agent = workflow.compile()
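LangGraph merges each node's returned update into the state key by key: a key declared without a reducer is simply overwritten, while one annotated with `add_messages` accumulates the conversation. The plain-Python simulation below (no LangGraph) shows why that matters for a tool-calling loop. `overwrite` mimics a key with no reducer, and `append_by_id` is a simplified imitation of `add_messages`, which appends new messages and replaces existing ones that share an ID (the real function also coerces message types and supports removal). `Msg`, `overwrite`, `append_by_id`, and `run_turns` are hypothetical names.

```python
from dataclasses import dataclass


@dataclass
class Msg:
    """Hypothetical minimal message; LangChain messages also carry an optional id."""
    id: str
    content: str


def overwrite(old: list[Msg], new: list[Msg]) -> list[Msg]:
    # Behaviour of a state key with no reducer: the last write wins
    return new


def append_by_id(old: list[Msg], new: list[Msg]) -> list[Msg]:
    # Simplified imitation of add_messages: append new messages, but replace any
    # existing message that has the same id
    merged = list(old)
    positions = {m.id: i for i, m in enumerate(merged)}
    for m in new:
        if m.id in positions:
            merged[positions[m.id]] = m
        else:
            positions[m.id] = len(merged)
            merged.append(m)
    return merged


def run_turns(reducer) -> list[str]:
    # Simulate: user question -> agent requests a tool -> tool returns a result
    state = [Msg("1", "question")]
    for update in ([Msg("2", "call factual_qa")], [Msg("3", "tool result")]):
        state = reducer(state, update)
    return [m.content for m in state]


print(run_turns(overwrite))     # prints: ['tool result']
print(run_turns(append_by_id))  # prints: ['question', 'call factual_qa', 'tool result']
```

With the overwriting behaviour, the model's next call would see only the tool result and lose both the question and the plan, so an appending reducer is what lets the agent reason over the whole exchange.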

In [None]:
# Create the initial message for the agent
initial_message = HumanMessage(
    content=f"Here is my question: '{user_question}'\n\n"
            f"Here is the plan to follow:\n{initial_plan}"
)

# The graph's input is a state dict; its `messages` key holds the conversation so far
agent_input = {"messages": [initial_message]}

print("--- Running the full retrieval agent ---")

# Stream the agent's execution: each output maps the node that just ran to its state update.
# (Top-level `async for` works here because Jupyter/Colab already run an event loop.)
async for output in retrieval_agent.astream(agent_input):
    for key, value in output.items():
        if key == "agent":
            agent_message = value['messages'][-1]
            if agent_message.tool_calls:
                for call in agent_message.tool_calls:
                    print(f"LLM wants to call a tool: {call['name']} with args {call['args']}")
            else:
                print("\n--- AGENT: Final Answer ---")
                print(agent_message.content)
        elif key == "action":
            print("--- AGENT: Tool response received. ---")
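Stripped of LangGraph, the agent graph defined earlier amounts to a bounded loop: an "agent" step asks the model what to do, an "action" step runs every requested tool, and the loop ends when the model replies without tool calls. The sketch below is a hypothetical plain-Python rendering of that control flow; `run_agent`, `lookup`, `StepLimitReached`, and both stub models are illustrative names. Its `max_steps` cap only loosely mirrors LangGraph's recursion limit, which counts graph steps rather than model calls. The second stub imitates a model bound with `tool_choice="any"`, which never produces a final answer.

```python
from typing import Callable


class StepLimitReached(RuntimeError):
    """Hypothetical stand-in for LangGraph's recursion-limit error."""


def run_agent(model: Callable[[list[dict]], dict],
              tools: dict[str, Callable[..., str]],
              question: str, max_steps: int = 10) -> str:
    messages = [{"role": "user", "content": question}]
    for _ in range(max_steps):
        reply = model(messages)                  # "agent" node
        messages.append(reply)
        if not reply.get("tool_calls"):          # should_continue -> "finish"
            return reply["content"]
        for call in reply["tool_calls"]:         # "action" node runs every requested tool
            result = tools[call["name"]](**call["args"])
            messages.append({"role": "tool", "content": result})
    raise StepLimitReached(f"no final answer after {max_steps} model calls")


def lookup(entity: str) -> str:
    # Hypothetical tool standing in for factual_qa
    return f"facts about {entity}"


def polite_model(messages: list[dict]) -> dict:
    # Requests one tool call, then answers using the tool output
    if messages[-1]["role"] == "user":
        return {"role": "assistant", "content": "",
                "tool_calls": [{"name": "lookup", "args": {"entity": "ExampleCo"}}]}
    return {"role": "assistant", "content": f"Answer based on: {messages[-1]['content']}"}


def always_tool_model(messages: list[dict]) -> dict:
    # Mimics tool_choice="any": every turn is a tool call
    return {"role": "assistant", "content": "",
            "tool_calls": [{"name": "lookup", "args": {"entity": "ExampleCo"}}]}


print(run_agent(polite_model, {"lookup": lookup}, "What changed?"))
# prints: Answer based on: facts about ExampleCo
try:
    run_agent(always_tool_model, {"lookup": lookup}, "What changed?", max_steps=3)
except StepLimitReached as err:
    print("Stopped:", err)
```

The cap is the design choice that keeps a tool-happy model from spinning forever; in LangGraph the equivalent knob is the `recursion_limit` entry in the run config.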