# Build an Agent with GraphRAG and Tool Calling based on TiDB

![image.png](https://lab-static.pingcap.com/images/2025/8/28/de762b23b7be6dde39a9c4192bcbbbdc2ef6c63f.png)

## Prerequisites

### Install Dependencies

Install the Python SDK for TiDB

In [None]:
%pip install -q pytidb==0.0.9 pytidb[models]==0.0.9 openai yfinance

Install dependencies to build and visualize the knowledge graph

In [None]:
%pip install -q dspy pyvis dotenv matplotlib

### Configure Secrets

Before you begin, ensure you have an **OpenAI API key**. Please follow the [OpenAI platform](https://platform.openai.com/account/api-keys) to get one if you do not have one yet.

In [None]:
import os
from getpass import getpass
import dotenv

# Pull variables from a local `.env` file into the process environment, if one is found.
dotenv.load_dotenv()

# Prompt interactively only when the key was supplied neither by the shell nor by `.env`.
if os.environ.get("OPENAI_API_KEY") is None:
    os.environ["OPENAI_API_KEY"] = getpass("OpenAI API Key:")

## Prepare company data

### Step 1. Connect to TiDB

In [None]:
from pytidb import TiDBClient

# Connect to TiDB using credentials read from environment variables.
# SERVERLESS_CLUSTER_PORT falls back to "4000" (TiDB's default SQL port) so an
# unset variable no longer fails with `int(None)` -> TypeError.
db = TiDBClient.connect(
    host=os.getenv("SERVERLESS_CLUSTER_HOST"),
    port=int(os.getenv("SERVERLESS_CLUSTER_PORT", "4000")),
    username=os.getenv("SERVERLESS_CLUSTER_USERNAME"),
    password=os.getenv("SERVERLESS_CLUSTER_PASSWORD"),
    database=os.getenv("SERVERLESS_CLUSTER_DATABASE_NAME"),
    # NOTE(review): presumably makes pytidb create the database when it is missing — confirm.
    ensure_db=True,
)

### Step 2. Define Embedding Function



PyTiDB provides an [automatic embedding](https://pingcap.github.io/ai/guides/auto-embedding/) feature. You can insert raw data, and pytidb will automatically generate embeddings based on the configuration and fill the corresponding vector fields.

To enable this feature, you first need to define an `EmbeddingFunction`.

In [None]:
from pytidb.embeddings import EmbeddingFunction

# Embedding function reused by the VectorField declarations in the table models:
# it calls OpenAI's `text-embedding-3-small` model through pytidb.
openai_embed = EmbeddingFunction(
    api_key=os.environ.get("OPENAI_API_KEY"),
    model_name="openai/text-embedding-3-small",
)

# Display the embedding dimension as the cell output.
openai_embed.dimensions

### Step 3. Create data models and tables

**Drop the existing tables**

In [None]:
reset_tables = True
if reset_tables:
    # Same drop order as before: `relationships` (which references `entities`
    # through foreign keys) is dropped before `entities`.
    for table_name in ("company_data", "relationships", "entities"):
        db.execute(f"DROP TABLE IF EXISTS {table_name};")

**Create data models and tables:**

For pytidb, you can add a `VectorField` to store vector embeddings, and a `FullTextField` to store the searchable text.

> **NOTE:** The initial setup of vector and full-text indexes may take some time.

In [None]:
from typing import Optional, List, Dict, Any
from datetime import datetime

from pytidb.schema import TableModel, Field, FullTextField, Column, Relationship as SQLRelationship
from pytidb.datatype import JSON, DateTime, TEXT
from sqlalchemy.dialects.mysql import YEAR
from pytidb.sql import func


# Create a `company_data` table to store the basic information of the companies and their embeddings.
class CompanyData(TableModel):
    """One row per company profile, plus an embedding generated from `ai_summary`."""

    __tablename__ = "company_data"
    # `extend_existing` must be a bool; it lets this cell be re-run without
    # SQLAlchemy rejecting a second definition of the same table in its MetaData.
    __table_args__ = {"extend_existing": True}

    vector_hash: str = Field(max_length=128, primary_key=True)
    company_name: Optional[str] = Field(max_length=255)
    name: Optional[str] = Field(max_length=255)
    stock_ticker: Optional[str] = Field(max_length=10)
    # 👇 Add a full-text field to store the text need to be searched.
    description: Optional[str] = FullTextField(sa_type=TEXT)
    industries: Optional[str] = Field(max_length=255)
    headquarters: Optional[str] = Field(sa_type=TEXT)
    specialties: Optional[str] = Field(sa_type=TEXT)
    employees: Optional[int]
    # Stored with MySQL's YEAR column type.
    founded: Optional[int] = Field(sa_type=YEAR)
    linkedin_url: Optional[str] = Field(max_length=512)
    website_url: Optional[str] = Field(max_length=512)
    organization_type: Optional[str] = Field(max_length=64)
    company_size: Optional[int]
    followers: Optional[int]
    slogan: Optional[str]
    locations: Optional[Dict[str, Any]] = Field(sa_type=JSON)
    formatted_locations: Optional[Dict[str, Any]] = Field(sa_type=JSON)
    logo: Optional[str] = Field(max_length=512)
    image: Optional[str] = Field(max_length=512)
    country_code: Optional[str] = Field(max_length=10)
    funding: Optional[Dict[str, Any]] = Field(sa_type=JSON)
    investors: Optional[List[Dict[str, Any]]] = Field(sa_type=JSON)
    crunchbase_url: Optional[str]
    # Source text for `embedding` below.
    ai_summary: Optional[str] = Field(sa_type=TEXT)
    about: Optional[str] = Field(sa_type=TEXT)
    similar: Optional[str]
    # 👇 Add a vector field to store the vector embedding generated by OpenAI.
    embedding: Optional[List[float]] = openai_embed.VectorField(
        source_field="ai_summary",
    )
    # Timestamps filled by the database via NOW() server defaults.
    created_at: Optional[datetime] = Field(
        sa_column=Column(
            DateTime(timezone=False), server_default=func.now(), nullable=True
        )
    )
    updated_at: Optional[datetime] = Field(
        sa_column=Column(
            DateTime(timezone=False), server_default=func.now(), onupdate=func.now(), nullable=True
        )
    )

# NOTE(review): mode="exist_ok" presumably reuses an existing table instead of failing — confirm.
company_table = db.create_table(schema=CompanyData, mode="exist_ok")


# Create `entities` table to store the entities extracted from the company information.
class DBEntity(TableModel):
    """A knowledge-graph node persisted in the `entities` table."""

    __tablename__ = "entities"
    # `extend_existing` must be a bool; it allows re-running this cell.
    __table_args__ = {"extend_existing": True}

    # NOTE(review): default=None presumably lets the database assign the key on insert.
    id: int = Field(default=None, primary_key=True)
    name: str = Field(default=None)
    description: str = Field(default=None, sa_type=TEXT)
    # Embedded from `entity_str` ("name: description"); find_most_similar_entity
    # builds its search text in the same format.
    embedding: List[float] = openai_embed.VectorField(
        source_field="entity_str"
    )
    # Caller-supplied metadata stored as JSON.
    meta: dict = Field(sa_type=JSON, default_factory=dict)

    @property
    def entity_str(self):
        """Text used to embed this entity: ``"<name>: <description>"``."""
        return f"{self.name}: {self.description}"

entity_table = db.create_table(schema=DBEntity, mode="exist_ok")


# Create `relationships` table to store the relationship between entities.
class DBRelationship(TableModel):
    """A directed edge between two `entities` rows, persisted in `relationships`."""

    __tablename__ = "relationships"
    # `extend_existing` must be a bool; it allows re-running this cell.
    __table_args__ = {"extend_existing": True}

    id: int = Field(default=None, primary_key=True)
    source_entity_id: int = Field(foreign_key="entities.id")
    target_entity_id: int = Field(foreign_key="entities.id")
    relationship_desc: str = Field(default=None, sa_type=TEXT)
    # Embedded from the relationship description text.
    embedding: Optional[List[float]] = openai_embed.VectorField(
        source_field="relationship_desc"
    )
    # Both endpoints are joined eagerly ("lazy": "joined") so `source_name` /
    # `target_name` can be read after the query returns. The join conditions
    # are plain strings resolved by SQLAlchemy (no f-string needed).
    source_entity: DBEntity = SQLRelationship(
        sa_relationship_kwargs={
            "primaryjoin": "DBRelationship.source_entity_id == DBEntity.id",
            "lazy": "joined",
        },
    )
    target_entity: DBEntity = SQLRelationship(
        sa_relationship_kwargs={
            "primaryjoin": "DBRelationship.target_entity_id == DBEntity.id",
            "lazy": "joined",
        },
    )

    @property
    def source_name(self) -> str:
        """Name of the source entity."""
        return self.source_entity.name

    @property
    def target_name(self) -> str:
        """Name of the target entity."""
        return self.target_entity.name

relationship_table = db.create_table(schema=DBRelationship, mode="exist_ok")

### Step 4. Insert sample company data

Before inserting the sample data, clear any existing rows in the `company_data` table so that you start from a clean state.

In [None]:
# Delete every existing row: `vector_hash` is the primary key, so no row has it NULL.
company_table.delete("vector_hash is not null")

Bulk insert the sample company data. The `embedding` field is filled in automatically from each company's `ai_summary`.

In [None]:
# Insert the sample companies in one batch. `embedding` is not set here: the
# VectorField on CompanyData derives it from each row's `ai_summary`.
company_table.bulk_insert([
    CompanyData(
        vector_hash="222f53a93827ece3d9a9cf0e918b112d495090a29e2a0248b6501a99bc4f6ffd",
        company_name="Apple Inc.",
        stock_ticker="AAPL",
        description="Apple Inc. is an American multinational technology company that specializes in consumer electronics, computer software, and online services. Apple is the world's largest technology company by revenue and one of the world's most valuable companies.",
        industries="Technology Hardware, Storage & Peripherals",
        headquarters="Cupertino, California, United States",
        specialties="iPhone, iPad, Mac, Apple Watch, AirPods, iOS, macOS, App Store, iCloud, Apple Music",
        employees=164000,
        founded=1976,
        linkedin_url="https://www.linkedin.com/company/apple/",
        website_url="https://www.apple.com",
        logo="https://www.apple.com/ac/structured-data/images/knowledge_graph_logo.png",
        ai_summary="Apple Inc. is a leading technology company founded in 1976 by Steve Jobs, Steve Wozniak, and Ronald Wayne. Headquartered in Cupertino, California, Apple designs, manufactures, and markets smartphones, personal computers, tablets, wearables, and accessories. The company is known for its innovative products including the iPhone, iPad, Mac computers, Apple Watch, and AirPods. Apple also provides digital services through the App Store, iCloud, Apple Music, and other platforms.",
    ),
    CompanyData(
        vector_hash="d77a83042ab8aa1a904359e7335bc6efa1d47bb38bab98b34ffc54a0e7fab16c",
        company_name="Microsoft Corporation",
        stock_ticker="MSFT",
        description="Microsoft Corporation is an American multinational technology company that produces computer software, consumer electronics, personal computers, and related services. Microsoft is known for its Windows operating system, Office productivity suite, and Azure cloud platform.",
        industries="Software—Infrastructure",
        headquarters="Redmond, Washington, United States",
        specialties="Windows, Microsoft Office, Azure, Xbox, Surface, Visual Studio, Teams, LinkedIn",
        employees=221000,
        founded=1975,
        linkedin_url="https://www.linkedin.com/company/microsoft/",
        website_url="https://www.microsoft.com",
        logo="https://img-prod-cms-rt-microsoft-com.akamaized.net/cms/api/am/imageFileData/RE1Mu3b",
        ai_summary="Microsoft Corporation is a multinational technology company founded in 1975 by Bill Gates and Paul Allen. Based in Redmond, Washington, Microsoft develops, manufactures, licenses, supports, and sells computer software, consumer electronics, personal computers, and related services. The company is best known for its Windows operating systems, Microsoft Office suite, and Azure cloud computing platform.",
    ),
    CompanyData(
        vector_hash="d745543f1fd7ec0985bded9d6800d10069280c0ad32c1a1f7f4c729f3352a8f5",
        company_name="Google (Alphabet Inc.)",
        stock_ticker="GOOGL",
        description="Alphabet Inc. is an American multinational technology conglomerate holding company headquartered in Mountain View, California. It was created through a restructuring of Google on October 2, 2015, and became the parent company of Google and several former Google subsidiaries.",
        industries="Internet Content & Information",
        headquarters="Mountain View, California, United States",
        specialties="Search, Advertising, Cloud Computing, YouTube, Android, Chrome, Google Workspace, AI, Machine Learning",
        employees=182502,
        founded=1998,
        linkedin_url="https://www.linkedin.com/company/google/",
        website_url="https://abc.xyz/",
        logo="https://www.google.com/images/branding/googlelogo/1x/googlelogo_color_272x92dp.png",
        ai_summary="Alphabet Inc., formerly known as Google, is a multinational technology conglomerate founded in 1998 by Larry Page and Sergey Brin. The company operates the world's most popular search engine and provides a wide range of internet-related services and products, including online advertising technologies, cloud computing, software, and hardware. Google's main products include Search, YouTube, Android, Chrome browser, and Google Cloud Platform.",
    ),
    CompanyData(
        # NOTE(review): this vector_hash is 66 characters and looks hand-made,
        # unlike e.g. the 64-character hash of the Apple entry — confirm it is intended.
        vector_hash="d4e5f6789012345678901234567890abcdef1234567890abcdef1234567890abcd",
        company_name="Amazon.com Inc.",
        stock_ticker="AMZN",
        description="Amazon.com, Inc. is an American multinational technology company which focuses on e-commerce, cloud computing, digital streaming, and artificial intelligence. It has been referred to as one of the most influential economic and cultural forces in the world.",
        industries="Internet Retail",
        headquarters="Seattle, Washington, United States",
        specialties="E-commerce, Cloud Computing, Digital Streaming, Artificial Intelligence, AWS, Prime, Alexa, Logistics",
        employees=1541000,
        founded=1994,
        linkedin_url="https://www.linkedin.com/company/amazon/",
        website_url="https://www.amazon.com",
        logo="https://logo.clearbit.com/amazon.com",
        ai_summary="Amazon.com, Inc. is a multinational technology company founded in 1994 by Jeff Bezos. Starting as an online bookstore, Amazon has evolved into one of the world's largest e-commerce and cloud computing companies. The company operates through three main segments: North America, International, and Amazon Web Services (AWS). Amazon is known for its customer-centric approach, innovation in logistics, and cloud computing leadership through AWS."
    ),
    CompanyData(
        vector_hash="a9775c122d0ecae9b6d67d1b42ee71190fc2300ed0b94326f95ff452e7cc25bb",
        company_name="Oracle Corporation",
        stock_ticker="ORCL",
        description="Oracle Corporation is a multinational technology company that specializes in developing and marketing database software and technology, cloud engineered systems, and enterprise software products.",
        industries="Enterprise Software, Database Technology, Cloud Computing",
        headquarters="Austin, Texas, United States",
        specialties="Oracle Database, MySQL, Java, Cloud Infrastructure, Enterprise Applications, ERP Systems, CRM Solutions, Cloud Services",
        employees=164000,
        founded=1977,
        linkedin_url="https://www.linkedin.com/company/oracle/",
        website_url="https://www.oracle.com",
        logo="https://www.oracle.com/a/ocom/img/oracle-logo.svg",
        ai_summary = (
          "Oracle Corporation, founded in 1977 and based in Austin, Texas, "
          "is a global leader in database technology, cloud computing, and enterprise software. "
          "Its flagship product is Oracle Database. In 2009, Oracle acquired MySQL by purchasing Sun Microsystems. "
          "The company offers cloud infrastructure, ERP systems, CRM solutions, and various enterprise applications. "
          "Serving over 430,000 customers worldwide with around 164,000 employees, "
          "Oracle has been a dominant player in the enterprise software market for decades "
          "and continues to grow its cloud services to compete with major providers."
      )
    ),
    CompanyData(
        vector_hash="0b6949174c686059cdfc61a1b905a0eacca17f63f703f9ec097be07dda67c7d2",
        company_name="Tesla, Inc.",
        stock_ticker="TSLA",
        description="Tesla, Inc. is an American multinational automotive and clean energy company headquartered in Austin, Texas. Tesla designs and manufactures electric vehicles, battery energy storage systems, and solar panels.",
        industries="Auto Manufacturers",
        headquarters="Austin, Texas, United States",
        specialties="Electric Vehicles, Energy Storage, Solar Panels, Autonomous Driving, Supercharger Network, Model S, Model 3, Model X, Model Y",
        employees=140473,
        founded=2003,
        linkedin_url="https://www.linkedin.com/company/tesla-motors/",
        website_url="https://www.tesla.com",
        logo="https://logo.clearbit.com/tesla.com",
        ai_summary="Tesla, Inc. is an electric vehicle and clean energy company founded in 2003 by Martin Eberhard and Marc Tarpenning, with Elon Musk joining as chairman in 2004. Based in Austin, Texas, Tesla is known for its electric vehicles including the Model S, 3, X, and Y, as well as energy storage systems and solar panels. The company has revolutionized the automotive industry with its focus on sustainable transportation and autonomous driving technology."
    ),
    CompanyData(
        vector_hash="827a4366f31cdf14d67c8e85555fa15d4641654095e91a25f34434d712986d12",
        company_name="Meta Platforms, Inc.",
        stock_ticker="META",
        description="Meta Platforms, Inc., doing business as Meta and formerly named Facebook, Inc., is an American multinational technology conglomerate based in Menlo Park, California. The company owns and operates Facebook, Instagram, Threads, and WhatsApp.",
        industries="Internet Content & Information",
        headquarters="Menlo Park, California, United States",
        specialties="Social Media, Virtual Reality, Augmented Reality, Metaverse, Facebook, Instagram, WhatsApp, Oculus, AI Research",
        employees=86482,
        founded=2004,
        linkedin_url="https://www.linkedin.com/company/meta/",
        website_url="https://about.meta.com",
        logo="https://logo.clearbit.com/meta.com",
        ai_summary="Meta Platforms, Inc., formerly Facebook, Inc., is a multinational technology company founded in 2004 by Mark Zuckerberg. Based in Menlo Park, California, Meta operates the world's largest social media platforms including Facebook, Instagram, WhatsApp, and Threads. The company is also investing heavily in virtual and augmented reality technologies as part of its vision for the metaverse."
    ),
    CompanyData(
        vector_hash="59a4745cd42a0e5a08c234906eae37de8617743572054b36bf51ae8793f08d74",
        company_name="NVIDIA Corporation",
        stock_ticker="NVDA",
        description="NVIDIA Corporation is an American multinational technology company incorporated in Delaware and based in Santa Clara, California. It is a software and fabless company which designs graphics processing units (GPUs), application programming interfaces (APIs) for data science and high-performance computing.",
        industries="Semiconductors",
        headquarters="Santa Clara, California, United States",
        specialties="Graphics Processing Units, Artificial Intelligence, Data Centers, Gaming, Autonomous Vehicles, Deep Learning, CUDA",
        employees=29600,
        founded=1993,
        linkedin_url="https://www.linkedin.com/company/nvidia/",
        website_url="https://www.nvidia.com",
        logo="https://logo.clearbit.com/nvidia.com",
        ai_summary="NVIDIA Corporation is a multinational technology company founded in 1993 by Jensen Huang, Chris Malachowsky, and Curtis Priem. Based in Santa Clara, California, NVIDIA is a pioneer in GPU technology and has become a leader in artificial intelligence computing. The company's GPUs are widely used in gaming, data centers, and AI applications, making it one of the most valuable semiconductor companies in the world."
    ),
    CompanyData(
        vector_hash="83f2e90b806c2e162a9f5c8ad54e02712f3113d8ffd2312db01a4f709723f9dd",
        company_name="Netflix, Inc.",
        stock_ticker="NFLX",
        description="Netflix, Inc. is an American subscription streaming service and production company based in Los Gatos, California. Netflix was founded in 1997 by Reed Hastings and Marc Randolph and is now available in over 190 countries.",
        industries="Entertainment",
        headquarters="Los Gatos, California, United States",
        specialties="Streaming Video, Original Content Production, Technology Platform, Global Entertainment, Data Analytics",
        employees=13000,
        founded=1997,
        linkedin_url="https://www.linkedin.com/company/netflix/",
        website_url="https://www.netflix.com",
        logo="https://logo.clearbit.com/netflix.com",
        ai_summary="Netflix, Inc. is a streaming entertainment service founded in 1997 by Reed Hastings and Marc Randolph. Originally a DVD-by-mail service, Netflix transformed into a streaming platform and became a global leader in entertainment. The company is known for its original content production including series like Stranger Things, The Crown, and Squid Game, serving over 230 million subscribers worldwide."
    ),
    CompanyData(
        vector_hash="b516acc08247f7479faa864ad0138644a574f804061e73ed658cf46edb22f9d6",
        company_name="Intel Corporation",
        stock_ticker="INTC",
        description="Intel Corporation is an American multinational corporation and technology company headquartered in Santa Clara, California. Intel designs and manufactures microprocessors for computer systems and other electronic devices.",
        industries="Semiconductors",
        headquarters="Santa Clara, California, United States",
        specialties="Microprocessors, Semiconductors, Data Centers, AI Chips, IoT Solutions, Autonomous Driving, 5G Technology",
        employees=121100,
        founded=1968,
        linkedin_url="https://www.linkedin.com/company/intel-corporation/",
        website_url="https://www.intel.com",
        logo="https://logo.clearbit.com/intel.com",
        ai_summary="Intel Corporation is a multinational technology company founded in 1968 by Robert Noyce and Gordon Moore. Based in Santa Clara, California, Intel is one of the world's largest semiconductor companies, best known for its x86 microprocessors found in most personal computers. The company has been a pioneer in semiconductor manufacturing and continues to innovate in areas such as data centers, artificial intelligence, and autonomous driving technology."
    ),
    CompanyData(
        vector_hash="c9875d436bb12f89a3c7e8d456789012ab34567890cd567890ef1234567890ab",
        company_name="Salesforce, Inc.",
        stock_ticker="CRM",
        description="Salesforce, Inc. is an American cloud-based software company headquartered in San Francisco, California. It provides customer relationship management (CRM) software and applications focused on sales, customer service, marketing automation, e-commerce, analytics, and application development.",
        industries="Software—Application",
        headquarters="San Francisco, California, United States",
        specialties="CRM, Cloud Computing, Sales Automation, Marketing Automation, Customer Service, Analytics, AI Platform, Trailhead",
        employees=79390,
        founded=1999,
        linkedin_url="https://www.linkedin.com/company/salesforce/",
        website_url="https://www.salesforce.com",
        logo="https://logo.clearbit.com/salesforce.com",
        ai_summary="Salesforce, Inc. is a cloud-based software company founded in 1999 by Marc Benioff and Parker Harris. Based in San Francisco, California, Salesforce pioneered the Software-as-a-Service (SaaS) model and is the world's leading customer relationship management (CRM) platform. The company provides a comprehensive suite of cloud-based applications for sales, service, marketing, and more, serving businesses of all sizes globally."
    ),
])

# Sanity check: report how many rows the table holds after the insert.
print(f"Number of companies: {company_table.rows()}")

### Step 5. Perform vector search

In [None]:
# Vector (semantic) search over the company embeddings, returned as a pandas DataFrame.
query_text = "Companies in health industry"
vector_results = (
    company_table.search(query_text, search_type="vector")
    .limit(10)
    .to_pandas()
)

# Show only the relevant columns; `_distance` and `_score` come from the search itself.
display_columns = ["vector_hash", "company_name", "description", "embedding", "_distance", "_score"]
vector_results[display_columns]

The `_distance` field in the search result represents the distance between the query and the matched document in the vector space. A smaller `_distance` means higher similarity and better relevance to your search.

For more details, see the [Vector Search](https://pingcap.github.io/ai/guides/vector-search/) section in the PyTiDB docs.

## Build GraphRAG Step-by-step



### Step 1. Set up OpenAI for DSPy

Use DSPy as the framework to construct and manage the prompts sent to the LLM.

In [None]:
import dspy

# Default language model for DSPy modules in this notebook (GPT-4o, up to 4096 output tokens).
open_ai = dspy.LM(
    model="gpt-4o",
    max_tokens=4096,
    api_key=os.environ.get("OPENAI_API_KEY"),
)
dspy.settings.configure(lm=open_ai)

### Step 2. Define Knowledge Graph Utils

#### extract_knowledge_graph

We use **dspy** to build a module that extracts key business entities and their relationships from company profile texts, automatically constructing a structured **business knowledge graph**. This helps capture companies, products, industries, customers, and ecosystem connections for better understanding of the business landscape.

In [None]:
from pydantic import BaseModel, Field
import dspy

class Entity(BaseModel):
    """List of entities extracted from the text to form the knowledge graph"""
    # NOTE(review): each instance is ONE entity; the list lives in
    # KnowledgeGraph.entities. The docstring and `description` strings are left
    # verbatim because they may appear in the schema the LLM is prompted with.

    # Entity name; save_knowledge_graph matches entities and relationship
    # endpoints by this value.
    name: str = Field(
        description="Name of the entity, it should be a clear and concise term"
    )
    # Full-sentence description; combined with `name` for embedding / similarity search.
    description: str = Field(
        description=(
            "Description of the entity, it should be a complete and comprehensive sentence, not few words. "
            "Sample description of entity 'PingCAP': "
            "'PingCAP is a leading distributed SQL database company founded in 2015, headquartered in Sunnyvale, California, and has over 500 employees.'"
        )
    )


class Relationship(BaseModel):
    """List of relationships extracted from the text to form the knowledge graph"""

    # Endpoints are referenced by entity *name*; save_knowledge_graph drops any
    # relationship whose source or target name does not resolve to a stored entity.
    source_name: str = Field(
        description="Source entity name of the relationship, it should be an existing entity in the Entity list"
    )
    target_name: str = Field(
        description="Target entity name of the relationship, it should be an existing entity in the Entity list"
    )
    # The adjacent literals below are concatenated into a single str. There must be
    # no trailing comma inside the parentheses, or `description` becomes a 1-tuple.
    relationship_desc: str = Field(
        description=(
            "Description of the relationship, it should be a complete and comprehensive sentence, not few words. "
            "Sample description of relationship (source: TiDB, target: MySQL): 'TiDB is highly compatible with MySQL.'"
        )
    )

class KnowledgeGraph(BaseModel):
    """Graph representation of the knowledge for text."""
    # Output type of the extractor; the docstring and field descriptions are
    # schema text the LLM may see, so they are kept verbatim.

    # Nodes extracted from the text.
    entities: List[Entity] = Field(
        description="List of entities in the knowledge graph"
    )
    # Edges; endpoints refer to entities by name.
    relationships: List[Relationship] = Field(
        description="List of relationships in the knowledge graph"
    )

# DSPy signature for knowledge-graph extraction.
# NOTE(review): DSPy presumably uses the class docstring as the task instructions
# and each field's `desc` in the prompt, so that text is behavior, not just docs.
# It contains a small typo ("for example. AWS RDS") that could be fixed deliberately.
class ExtractGraphTriplet(dspy.Signature):
    """Extract a comprehensive knowledge graph from LinkedIn company profiles by identifying business entities and their relationships.

## Entity Extraction Rules:
- Extract meaningful business entities: companies, products, technologies, services, locations, industries
- Use specific, self-explanatory names (avoid generic terms like "product", "technology")
- Consolidate similar entities to prevent redundancy
- Include both high-level concepts and technical details

Good entities: "TiDB Distributed Database", "MySQL Compatibility", "HTAP Architecture"
Bad entities: "AWS", "Product", "Technology", "111 Employees"

## Relationship Identification:
- Identify clear directional relationships between entities
- Focus on business dependencies, technical integrations, market competition
- Ensure each relationship has textual evidence

Common patterns:
- Company_A develops Product_B: for example, PingCAP develops TiDB Cloud
- Product_C compatible_with Standard_D
- Service_E targets Market_F
- Service_G based_on Product_H: for example. AWS RDS for MySQL is a fully managed service based on MySQL.

## Quality Requirements:
- Extract ALL meaningful entities and relationships in one pass
- Prioritize business-relevant information
- Maintain accurate relationship directionality
- Use specific, actionable relationship descriptions
- Ensure entity names are immediately understandable

Generate a complete knowledge graph that captures the company's business model, technical capabilities, and market positioning.
    """
    # Input: raw company profile text.
    text = dspy.InputField(
        desc="A company introduction paragraph from LinkedIn or similar source"
    )
    # Output: parsed into the KnowledgeGraph pydantic model.
    knowledge: KnowledgeGraph = dspy.OutputField(
        desc="Graph representation of the business knowledge extracted from the text."
    )

class Extractor(dspy.Module):
    """DSPy module wrapping a single `Predict` call on `ExtractGraphTriplet`."""

    def __init__(self):
        super().__init__()
        # Attribute name kept unchanged: DSPy presumably keys sub-predictor state by it.
        self.prog_graph = dspy.Predict(ExtractGraphTriplet)

    def forward(self, text):
        prediction = self.prog_graph(text=text)
        return prediction

# Single module-level extractor reused by every call below.
kg_extractor = Extractor()

def extract_knowledge_graph(text) -> KnowledgeGraph:
    """Run the LLM extractor on ``text`` and return its parsed KnowledgeGraph."""
    result = kg_extractor(text=text)
    return result.knowledge

#### save_knowledge_graph

In [None]:
def get_entities_by_names(names: List[str]) -> List[DBEntity]:
    """Return stored entities whose ``name`` exactly equals one of ``names``.

    A falsy ``names`` (e.g. an empty list) returns ``[]`` without querying.
    """
    return (
        entity_table.query(filters={"name": {"$in": names}}).to_pydantic()
        if names
        else []
    )

def find_most_similar_entity(entity: Entity) -> Optional[DBEntity]:
    """Return the single closest stored entity to ``entity``, or ``None``.

    The search text uses the same ``"name: description"`` format as
    ``DBEntity.entity_str``; results are filtered with ``distance_threshold(0.3)``.
    """
    search_text = f"{entity.name}: {entity.description}"
    candidates = (
        entity_table.search(search_text)
        .distance_threshold(0.3)
        .limit(1)
        .to_pydantic()
    )
    if not candidates:
        return None
    return candidates[0]

def save_knowledge_graph(kg: KnowledgeGraph, metadata: Optional[Dict[str, Any]] = None):
    """Persist an extracted knowledge graph, reusing entities that already exist.

    Entities are resolved to stored rows first by exact name, then by semantic
    similarity; only the remaining ones are inserted (once per distinct name).
    Relationships are inserted only when both endpoint names resolve.

    Args:
        kg: The knowledge graph to save.
        metadata: Optional dict stored in the ``meta`` column of each newly
            inserted entity. Defaults to an empty dict. Each entity receives its
            own shallow copy, so entities never share one mutable dict.
    """
    # Avoid a shared mutable default argument.
    if metadata is None:
        metadata = {}

    # NOTE(review): confirm that entity_table / relationship_table writes join this
    # session; if they use their own, the commit() below is effectively a no-op.
    with db.session() as session:
        # Step 1: Find existing entities that have the exact same name.
        same_entities = get_entities_by_names([e.name for e in kg.entities])
        entity_name_to_id = {e.name: e.id for e in same_entities}

        # Step 2: For unmatched entities, try to find similar ones by semantic search.
        for entity in kg.entities:
            if entity.name in entity_name_to_id:
                continue  # Already matched by name.
            similar_entity = find_most_similar_entity(entity)
            if similar_entity:
                entity_name_to_id[entity.name] = similar_entity.id

        # Step 3: Insert the remaining entities that are completely new. The LLM
        # may emit the same name more than once, so keep only the first occurrence.
        entities_to_add: Dict[str, DBEntity] = {}
        for e in kg.entities:
            if e.name in entity_name_to_id or e.name in entities_to_add:
                continue
            entities_to_add[e.name] = DBEntity(
                name=e.name, description=e.description, meta=dict(metadata)
            )
        if entities_to_add:
            new_entities = entity_table.bulk_insert(list(entities_to_add.values()))
            entity_name_to_id.update({e.name: e.id for e in new_entities})

        # Step 4: Insert relationships only if both source and target entities exist.
        relationships = [
            DBRelationship(
                source_entity_id=entity_name_to_id[r.source_name],
                target_entity_id=entity_name_to_id[r.target_name],
                relationship_desc=r.relationship_desc
            )
            for r in kg.relationships
            if r.source_name in entity_name_to_id and r.target_name in entity_name_to_id
        ]
        if relationships:
            relationship_table.bulk_insert(relationships)

        session.commit()

#### retrieve_knowledge_graph

In [None]:
from collections import deque
from sqlalchemy import select, or_
from sqlalchemy.orm import joinedload


class RetrievedKnowledgeGraph(BaseModel):
    """Subgraph returned by retrieve_knowledge_graph: visited entities and the relationships traversed between them."""
    # Entities reached from the search seeds (seeds included).
    entities: List[DBEntity]
    # Relationships touching any visited entity, each included once.
    relationships: List[DBRelationship]


def retrieve_knowledge_graph(query, max_depth=3, top_k=10) -> RetrievedKnowledgeGraph:
    """Retrieve the subgraph relevant to ``query``.

    The ``top_k`` entities returned by ``entity_table.search(query)`` seed a
    breadth-first expansion of up to ``max_depth`` hops along stored relationships.

    Args:
        query: Free-text search query.
        max_depth: Maximum number of hops from any seed entity.
        top_k: Number of seed entities taken from the search.

    Returns:
        A RetrievedKnowledgeGraph; both lists are empty when the search finds no seed.
    """
    with db.session() as session:
        start_entities = entity_table.search(query).limit(top_k).to_pydantic()
        if not start_entities:
            # Honor the declared return type instead of returning a bare tuple,
            # so callers can always read `.entities` / `.relationships`.
            return RetrievedKnowledgeGraph(entities=[], relationships=[])

        entities, relationships = knowledge_graph_bfs(session, start_entities, max_depth)
        return RetrievedKnowledgeGraph(
            entities=[DBEntity.model_validate(e) for e in entities],
            relationships=[DBRelationship.model_validate(r) for r in relationships],
        )


def get_connected_relationships(session, entity_id):
    """Return every relationship whose source or target is ``entity_id``.

    Both endpoint entities are loaded in the same query via ``joinedload``.
    """
    touches_entity = or_(
        DBRelationship.source_entity_id == entity_id,
        DBRelationship.target_entity_id == entity_id,
    )
    eager_endpoints = (
        joinedload(DBRelationship.source_entity),
        joinedload(DBRelationship.target_entity),
    )
    stmt = select(DBRelationship).options(*eager_endpoints).where(touches_entity)
    result = session.execute(stmt)
    return result.scalars().all()

def knowledge_graph_bfs(session, start_entities, max_depth: int):
    """Breadth-first traversal from several seed entities at once.

    Seeds sit at depth 0; relationships are only expanded from entities whose
    depth is below ``max_depth``. Each entity (by ``id``) and each relationship
    (by ``id``) is collected once.

    Returns:
        ``(entities, relationships)`` as two lists in discovery order.
    """
    seen_entities = {}
    for seed in start_entities:
        seen_entities[seed.id] = seed
    seen_relationships = {}

    frontier = deque((seed, 0) for seed in start_entities)

    while frontier:
        current, level = frontier.popleft()
        if level < max_depth:
            for relationship in get_connected_relationships(session, current.id):
                if relationship.id not in seen_relationships:
                    seen_relationships[relationship.id] = relationship
                    for endpoint in (relationship.source_entity, relationship.target_entity):
                        if endpoint.id not in seen_entities:
                            seen_entities[endpoint.id] = endpoint
                            frontier.append((endpoint, level + 1))

    return list(seen_entities.values()), list(seen_relationships.values())


#### visualize_knowledge_graph

**Define the structures that can be visualized:**

In [None]:
from typing import Protocol

class VisualizableEntity(Protocol):
    """Structural type for anything that can be drawn as a graph node."""

    name: str  # node label; relationships refer to nodes by this name
    description: str  # hover text shown on the node
    meta: Dict[str, Any]  # free-form metadata, e.g. a company id used for coloring


class VisualizableRelationship(Protocol):
    """Structural type for anything that can be drawn as an edge between named nodes."""

    source_name: str  # name of the entity at one end of the edge
    target_name: str  # name of the entity at the other end of the edge
    relationship_desc: str  # hover text shown on the edge


class VisualizableKnowledgeGraph(Protocol):
    """Structural type for a graph that ``visualize_knowledge_graph`` can render."""

    entities: List[VisualizableEntity]
    relationships: List[VisualizableRelationship]

**visualize_knowledge_graph**

In [None]:
from typing import Callable
from pyvis.network import Network
from IPython.display import display, HTML
import os
import time
import http.server
import socketserver
import threading
import webbrowser

def start_local_server(port=8000, default_page="company.html", host="127.0.0.1"):
    """Serve the current working directory over HTTP; blocks forever.

    Intended to run in a background thread. Requests for ``/`` or
    ``/index.html`` are answered with *default_page*, so a bare
    ``http://localhost:<port>/`` link opens the generated graph.

    Args:
        port: TCP port to listen on.
        default_page: File, relative to the working directory, served at ``/``.
        host: Interface to bind. Defaults to loopback because the handler
            exposes every file under the working directory, which may include
            a ``.env`` holding API keys and database credentials. Pass ``""``
            to deliberately listen on all interfaces.
    """
    class MyHandler(http.server.SimpleHTTPRequestHandler):
        def do_GET(self):
            # Map the site root to the graph page.
            if self.path in ("/", "/index.html"):
                self.path = "/" + default_page
            return super().do_GET()

    class _ReusableTCPServer(socketserver.TCPServer):
        # Allow re-binding a port left in TIME_WAIT by a previous server.
        allow_reuse_address = True

    try:
        httpd = _ReusableTCPServer((host, port), MyHandler)
    except OSError as err:
        # Typically the port is still held by a server from an earlier call.
        print(f"Could not start HTTP server on {host or '*'}:{port}: {err}")
        return

    with httpd:
        httpd.serve_forever()

def visualize_knowledge_graph(
    kg: VisualizableKnowledgeGraph,
    filename: Optional[str],
    port: int,
    custom_node_fn: Optional[Callable[[VisualizableEntity], dict]] = None
):
    """Render *kg* to an HTML file with pyvis and display a link to open it.

    Args:
        kg: Graph whose entities become nodes and relationships become edges.
            Nodes are keyed by entity name: for duplicate names the last entity
            wins. Edges whose endpoint names are not among the entities are
            dropped.
        filename: Output HTML path. The local server serves the current working
            directory, so this should be a relative path. ``None`` or ``""``
            selects a timestamped ``graph_<time>.html``.
        port: Port for the background HTTP server and for the displayed link.
        custom_node_fn: Optional callback returning extra pyvis node properties
            (e.g. ``{"color": ...}``) for each entity.

    NOTE(review): every call starts a new server thread. Reusing a port that an
    earlier call already bound makes the new thread fail to bind, so use
    distinct ports per view.
    """
    net = Network(notebook=True, cdn_resources='remote')
    # Relationships reference entities by name; this mapping resolves names to node ids.
    node_name_to_id = {e.name: i for i, e in enumerate(kg.entities)}

    for node_id in node_name_to_id.values():
        entity = kg.entities[node_id]
        node_properties = custom_node_fn(entity) if custom_node_fn is not None else {}
        net.add_node(node_id, label=entity.name, title=entity.description, **node_properties)

    for r in kg.relationships:
        src = node_name_to_id.get(r.source_name)
        tgt = node_name_to_id.get(r.target_name)
        # Skip edges pointing at entities that are not part of this graph.
        if src is not None and tgt is not None:
            net.add_edge(src, tgt, title=r.relationship_desc)

    if not filename:
        filename = f"graph_{time.time()}.html"

    net.save_graph(filename)

    # Serve in a daemon thread so the notebook cell returns immediately.
    server_thread = threading.Thread(
        target=start_local_server,
        args=(port, filename),
        daemon=True
    )
    server_thread.start()

    link_html = f'<a href="http://localhost:{port}/" target="_blank">Click here to open knowledge graph in new tab</a>'
    display(HTML(link_html))


**custom_node_fn**

A function that colors each node by its company ID, so that each company's subgraph of the knowledge graph is visually distinct.

In [None]:
import zlib

# Palette cycled through by company id.
COLORS = [
    "#5470C6", "#91CC75", "#FAC858", "#EE6666", "#73C0DE",
    "#3BA272", "#FC8452", "#9A60B4", "#EA7CCC",
    "#5A9BD5", "#A5A5A5", "#FFD700", "#FF6347", "#40E0D0",
    "#8A2BE2", "#FF69B4", "#00FA9A", "#D2691E", "#6495ED"
]

def get_company_color(company_id: str) -> str:
    """Map a company id to a palette color, stable across interpreter runs.

    Built-in ``hash()`` on strings is salted per process (PYTHONHASHSEED), so
    the same company could get a different color after a kernel restart.
    CRC-32 of the UTF-8 text is deterministic.
    """
    digest = zlib.crc32(str(company_id).encode("utf-8"))
    return COLORS[digest % len(COLORS)]

def custom_node_fn(entity):
    """Return pyvis node properties that color *entity* by its company id.

    Entities without a ``meta`` attribute, or whose ``meta`` is empty/None,
    fall into the ``"default"`` color bucket.
    """
    # ``or {}`` guards against ``meta`` being present but None.
    meta = getattr(entity, "meta", None) or {}
    company_id = meta.get("company_id", "default")
    return {"color": get_company_color(company_id)}

#### answer_with_knowledge_graph

In [None]:
class AnswerWithKG(dspy.Signature):
    # NOTE: dspy Signatures use the class docstring below as the task
    # instructions sent to the LM, so editing it changes the prompt.
    """
    Model signature for answering questions using a retrieved knowledge graph.

    Inputs:
        knowledge_graph_context (str): Formatted entities and relationships.
        question (str): Question to answer.

    Output:
        answer (str): Generated answer based on the knowledge graph and question,
                      formatted in Markdown.
    """
    # Text listing of entities and relationships (see the docstring).
    knowledge_graph_context: str = dspy.InputField()
    question: str = dspy.InputField()
    # The docstring asks for Markdown here.
    answer: str = dspy.OutputField()


def answer_with_knowledge_graph(question: str) -> str:
    """Answer *question* using only the knowledge graph retrieved for it.

    Retrieves up to 10 seed entities, expands them to depth 3, serializes the
    sub-graph as plain text, and asks ``dspy.Predict(AnswerWithKG)`` to answer
    from that context. No tools are called.
    """
    graph = retrieve_knowledge_graph(question, max_depth=3, top_k=10)

    entity_lines = [f"{ent.name}: {ent.description}" for ent in graph.entities]
    relationship_lines = [
        f"{rel.source_name} -> {rel.relationship_desc} -> {rel.target_name}"
        for rel in graph.relationships
    ]
    context = (
        "Entities:\n" + "\n".join(entity_lines)
        + "\n\nRelationships:\n" + "\n".join(relationship_lines)
        + "\n\n"
    )

    predictor = dspy.Predict(AnswerWithKG)
    prediction = predictor(knowledge_graph_context=context, question=question)
    return prediction.answer

#### clear_knowledge_graph

In [None]:
from sqlalchemy import text

def clear_knowledge_graph():
    """Delete all rows from the relationship and entity tables.

    Foreign-key checks are disabled on the session for the truncates and
    re-enabled in a ``finally`` block. A failed truncate therefore cannot leave
    ``FOREIGN_KEY_CHECKS = 0`` set on the connection.

    NOTE(review): the SET only affects the truncates if pytidb runs them on this
    same active ``db.session()``. That appears to be the intent; confirm.
    """
    with db.session() as session:
        session.execute(text("SET FOREIGN_KEY_CHECKS = 0;"))
        try:
            # Relationships (which hold source/target entity ids) go first,
            # then the entities they point to.
            relationship_table.truncate()
            entity_table.truncate()
        finally:
            session.execute(text("SET FOREIGN_KEY_CHECKS = 1;"))

### Step 3. Define Tool-Calling Functions

#### Define Stock Price Query Tool

In [None]:
import yfinance as yf
from datetime import datetime, timedelta
import pandas as pd

def get_stock_price_history(ticker: str, period: str = "30d") -> str:
    """
    Fetch historical stock price data for a given ticker symbol.

    Args:
        ticker: Stock ticker symbol (e.g., 'AAPL', 'MSFT')
        period: Time period for historical data (e.g., '30d', '1y', '3mo').
                Passed unchanged to ``yfinance.Ticker.history``; documented
                periods: 1d,5d,1mo,3mo,6mo,1y,2y,5y,10y,ytd,max

    Returns:
        A JSON object string with the current price, start-of-period price,
        absolute and percentage change, period high/low, average volume and
        basic company metadata. On failure this returns a plain-text message
        starting with "No data found" or "Error fetching data" instead of
        raising, so the calling agent can read the problem.
    """
    import json

    try:
        stock = yf.Ticker(ticker)

        # Get historical data
        hist = stock.history(period=period)
        if hist.empty:
            return f"No data found for ticker {ticker}"

        # Company metadata is optional for the metrics below. If fetching it
        # fails, degrade to an empty dict rather than discard the price history.
        try:
            info = stock.info or {}
        except Exception:
            info = {}

        # ``currentPrice`` may be absent or explicitly None; use the last close then.
        current_price = info.get('currentPrice')
        if current_price is None:
            current_price = hist['Close'].iloc[-1]

        # Calculate metrics
        start_price = hist['Close'].iloc[0]
        price_change = current_price - start_price
        pct_change = (price_change / start_price) * 100

        # Summary statistics
        max_price = hist['High'].max()
        min_price = hist['Low'].min()
        avg_volume = hist['Volume'].mean()

        result = {
            "ticker": ticker,
            "period": period,
            "current_price": round(current_price, 2),
            "start_price": round(start_price, 2),
            "price_change": round(price_change, 2),
            "percentage_change": round(pct_change, 2),
            "period_high": round(max_price, 2),
            "period_low": round(min_price, 2),
            "average_volume": int(avg_volume),
            "company_name": info.get('longName', 'N/A'),
            "currency": info.get('currency', 'USD'),
            "market_cap": info.get('marketCap', 'N/A'),
            "last_updated": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        }

        # Real JSON (not ``str(dict)``), as the docstring promises.
        return json.dumps(result)

    except Exception as e:
        return f"Error fetching data for {ticker}: {str(e)}"

# Create DSPy Tool for stock price querying
stock_price_tool = dspy.Tool(
    func=get_stock_price_history,
    name="get_stock_price_history",
    desc="Retrieve historical stock price data and current financial metrics for a given ticker symbol"
)

#### Financial Analysis with Tool Calling

In [None]:
class FinancialAnalysisSignature(dspy.Signature):
    # NOTE: dspy uses this docstring and the ``desc=`` strings below as prompt
    # text for the LM, so editing them changes the agent's instructions.
    """
    Answer financial questions using knowledge graph context and available tools.
    You have access to a stock price tool that can fetch current and historical stock data.
    Use the tools when specific financial data or stock prices are requested.
    """
    question: str = dspy.InputField(desc="Financial question or analysis request")
    knowledge_graph_context: str = dspy.InputField(desc="Relevant entities and relationships from the knowledge graph")
    answer: str = dspy.OutputField(desc="Comprehensive financial analysis in Markdown format")

class FinancialAgent(dspy.Module):
    """ReAct agent answering financial questions, with the stock price tool available.

    Wraps ``dspy.ReAct`` over ``FinancialAnalysisSignature``. The agent may call
    ``stock_price_tool`` and runs for at most three iterations.
    """

    def __init__(self):
        super().__init__()
        available_tools = [stock_price_tool]
        # ``max_iters`` bounds how many reason/act steps one question may take.
        self.react = dspy.ReAct(
            signature=FinancialAnalysisSignature,
            tools=available_tools,
            max_iters=3,
        )

    def forward(self, question: str, knowledge_graph_context: str):
        """Run the ReAct loop; the returned prediction exposes ``.answer``."""
        prediction = self.react(
            question=question,
            knowledge_graph_context=knowledge_graph_context,
        )
        return prediction

def answer_with_knowledge_graph_and_tool_calling(question: str) -> str:
    """Answer *question* from retrieved graph context plus optional tool calls.

    Builds the same text context as the knowledge-graph-only path, then hands it
    to a fresh ``FinancialAgent``, which may call the stock price tool.
    """
    def _render(items, fmt):
        # Join one formatted line per item.
        return "\n".join(fmt(item) for item in items)

    graph = retrieve_knowledge_graph(question, max_depth=3, top_k=10)

    entities_text = _render(graph.entities, lambda e: f"{e.name}: {e.description}")
    relationships_text = _render(
        graph.relationships,
        lambda r: f"{r.source_name} -> {r.relationship_desc} -> {r.target_name}",
    )
    context = f"Entities:\n{entities_text}\n\nRelationships:\n{relationships_text}\n\n"

    agent = FinancialAgent()
    return agent(question=question, knowledge_graph_context=context).answer

### Step 4. Construct a knowledge graph for a company

In [None]:
import pandas as pd

# Take one stored company as a sample and extract a knowledge graph from its
# AI-generated summary. This cell only extracts; nothing is saved here.
test_company = company_table.query(limit=1).to_pydantic()[0]
knowledge_graph = extract_knowledge_graph(test_company.ai_summary)

**The entities of generated knowledge graph:**

In [None]:
# One row per extracted entity; columns come from its pydantic fields.
entities_df = pd.DataFrame([entity.model_dump() for entity in knowledge_graph.entities])
entities_df

**The relationships of generated knowledge graph:**

In [None]:
# One row per extracted relationship; columns come from its pydantic fields.
relationships_df = pd.DataFrame([rel.model_dump() for rel in knowledge_graph.relationships])
relationships_df

**Company Knowledge Graph Visualization:**

In [None]:
# Render the single-company graph and serve it on port 8000.
visualize_knowledge_graph(
    kg=knowledge_graph,
    filename="company.html",
    port=8000,
)

### Step 5. Build a cross-organizational knowledge graph



**(Optional) Clear all knowledge graph data from the database:**

In [None]:
# Truncates both knowledge-graph tables; skip this cell to keep existing data.
clear_knowledge_graph()

**Build knowledge graph for companies:**

In [None]:
import logging
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed

# Build graphs for at most 20 companies from the company table.
companies = company_table.query(limit=20).to_pydantic()

def extract_company_knowledge_graph(company):
    """Extract a knowledge graph from one company's AI summary and save it.

    The company's name and ``vector_hash`` (as ``company_id``) are stored as
    metadata with the graph. Failures are logged with their traceback and
    swallowed, so one bad company does not abort the batch.
    """
    try:
        knowledge_graph = extract_knowledge_graph(company.ai_summary)
        save_knowledge_graph(knowledge_graph, {
            "company_name": company.company_name,
            "company_id": company.vector_hash
        })
    except Exception as err:
        # ``logging.exception`` records the message and the traceback in one entry.
        logging.exception("Failed to process company (%s): %s", company.company_name, err)

# Build knowledge graphs for the companies on a small thread pool. The worker
# already logs its own failures; anything that still escapes is logged here.
with ThreadPoolExecutor(max_workers=5) as pool:
    pending = [
        pool.submit(extract_company_knowledge_graph, company)
        for company in companies
    ]
    for finished in as_completed(pending):
        try:
            finished.result()
        except Exception as exc:
            logging.error(f"Future failed: {exc}")


**Number of entities:**

In [None]:
# Number of stored entities (pytidb ``Table.rows()``).
entity_table.rows()

**Number of relationships:**

In [None]:
# Number of stored relationships (pytidb ``Table.rows()``).
relationship_table.rows()

### Step 6. Retrieve knowledge graph

**Input a question:**

In [None]:
# The trailing "# @param" comment makes this a Colab form field; keep it intact.
question = "Analyze Tesla's stock price trend over the last 30 days in context of their EV market position"   # @param {type:"string"}

In [None]:
# Retrieve the sub-graph around the question and render it, coloring nodes by company.
knowledge_graph = retrieve_knowledge_graph(question, max_depth=3, top_k=10)
visualize_knowledge_graph(
    kg=knowledge_graph,
    filename="whole_graph.html",
    port=8004,
    custom_node_fn=custom_node_fn,
)

### Step 7. Chat with Knowledge Graph

In [None]:
from IPython.display import display, Markdown
# Answer from knowledge-graph context alone (no tool calls) and render the
# Markdown the model returns.
answer = answer_with_knowledge_graph(question)
display(Markdown(answer))

### Step 8. Financial Analysis with Tool Calling

In [None]:
# Same question, but the agent may also call the stock price tool.
analysis = answer_with_knowledge_graph_and_tool_calling(question)
display(Markdown(analysis))

### Step 9. Review the LLM Token Usage

In [None]:
# NOTE(review): assumes ``open_ai`` is the dspy LM configured earlier and that
# its ``history`` holds one entry per LM call; confirm.
api_call_amount = len(open_ai.history)
print(f"In this demo, we called LLM API of OpenAI {api_call_amount} times.")

#### Token Consumption

In [None]:
import pandas as pd

hist = open_ai.history
token_columns = ['prompt_tokens', 'completion_tokens', 'total_tokens']

rows = []
for h in hist:
    # Treat a missing or None ``usage`` as "no token data" rather than crashing.
    usage = h.get('usage') or {}
    rows.append({col: usage.get(col) for col in token_columns})

# Explicit columns keep the lookups below valid even when the history is empty.
df = pd.DataFrame(rows, columns=token_columns)
# Calls without usage data count as 0 tokens, keeping the counts integral.
df = df.fillna(0).astype(int)

totals = df.sum()
print(f"Total prompt_tokens: {totals['prompt_tokens']}")
print(f"Total completion_tokens: {totals['completion_tokens']}")
print(f"Total total_tokens: {totals['total_tokens']}")

#### Cumulative Token Usage Breakdown

In [None]:
import matplotlib.pyplot as plt

labels = ['Prompt tokens', 'Completion tokens']
sizes = [totals['prompt_tokens'], totals['completion_tokens']]

# A pie needs a positive total. With no recorded usage (e.g. no LLM calls),
# skip the chart instead of plotting an undefined breakdown.
if sum(sizes) > 0:
    plt.figure(figsize=(6, 6))
    plt.pie(sizes, labels=labels, autopct='%1.1f%%', startangle=140)
    plt.title('Cumulative Token Usage Breakdown')
    plt.axis('equal')
    plt.show()
else:
    print("No token usage recorded; skipping the usage breakdown chart.")

#### Per‑Call Token Usage Comparison

In [None]:
plt.figure(figsize=(10, 6))
call_labels = df.index.astype(str)

bar_width = 0.25
positions = range(len(df))
# (column, legend label) pairs drawn side by side, each shifted one bar width right.
series = [
    ('prompt_tokens', 'Prompt tokens'),
    ('completion_tokens', 'Completion tokens'),
    ('total_tokens', 'Total tokens'),
]
for offset, (column, label) in enumerate(series):
    shifted = [p + offset * bar_width for p in positions]
    plt.bar(shifted, df[column], width=bar_width, label=label)

plt.xlabel('Calling Index')
plt.ylabel('Token Consumption')
plt.title('Per‑Call Token Usage Comparison')
# Center each tick label under the middle bar of its group.
plt.xticks([p + bar_width for p in positions], call_labels, rotation=45)
plt.legend()
plt.tight_layout()
plt.show()