In [1]:
import os, uuid, json, time
import pandas as pd
from tqdm import tqdm

from google import genai
from google.genai import types
from google.genai.errors import ServerError

from typing import List, Optional, Any, Dict
from pydantic import BaseModel, Field

In [2]:
# Shared Gemini client. No API key is passed explicitly here, so it is
# presumably picked up from the environment — TODO confirm.
client = genai.Client()

# Model name and sampling temperature used by call_gemini below.
MODEL = "gemini-2.5-flash"
TEMPERATURE = 0.0

In [3]:
# Load the earnings-call component table. The printed shape and the one-row
# preview are sanity checks for the columns used later in the notebook.
call_section = pd.read_parquet("data/corpus/tech_call_sections.parquet")
print(call_section.shape)
call_section.head(1)

(10255, 21)


Unnamed: 0,companyid,company_name,market_cap,transcriptid,headline,mostimportantdateutc,mostimportanttimeutc,keydeveventtypeid,keydeveventtypename,transcriptcollectiontypeid,...,transcriptpersonname,speakertypename,componentorder,componenttext,segment_id,segment_idx,call_year,call_quarter,call_period,clean_text
0,19691.0,CISCO SYSTEMS INC,269596.8,3372444.0,"Cisco Systems, Inc., Q2 2025 Earnings Call, Fe...",2025-02-12,21:30:00,48.0,Earnings Calls,7,...,Operator,Operator,0,[Audio Gap],3372444.0_0,0,2025,Q1,2025 Q1,[Audio Gap]


In [4]:
# String copy of segment_id; used as the statement key in every downstream table.
call_section.loc[:,"statement_id"] = call_section["segment_id"].astype(str)

# Only derive clean_text from the raw component text when the corpus did not
# already ship a cleaned column.
if "clean_text" not in call_section.columns:
    call_section["clean_text"] = call_section["componenttext"].fillna("").astype(str)

In [42]:
from pydantic import BaseModel
from typing import Optional

class StatementRow(BaseModel):
    # Typed view of one row of the call-section table.
    #
    # Two fields are widened (backward-compatibly) to match the data shown in
    # this notebook:
    #   * companyid   - the corpus stores it as a float (e.g. 24937.0), which a
    #                   plain `str` field rejects under pydantic v2.
    #   * call_quarter - the corpus stores labels such as "Q4", which an `int`
    #                   field rejects.
    # NOTE(review): comments are used instead of a docstring because pydantic
    # copies a class docstring into the generated schema description.

    # --- required identity / text fields ---
    statement_id: str
    companyid: float | str
    company_name: str
    call_period: str
    transcriptpersonname: str
    speakertypename: str
    transcriptcomponenttypename: str
    clean_text: str

    # --- optional provenance fields ---
    segment_id: Optional[str] = None
    segment_idx: Optional[int] = None
    call_id: Optional[str] = None
    call_year: Optional[int] = None
    call_quarter: Optional[int | str] = None
    mostimportantdateutc: Optional[str] = None

In [43]:
from typing import List, Optional, Literal
from pydantic import BaseModel, Field

# Closed vocabularies for the extraction schema. The order of the values is
# kept exactly as declared because it is reflected in the generated schema.

# Metric-level enums.
MetricDirection = Literal["up", "down", "flat", "mixed", "unknown"]
MetricValueType = Literal["level", "change_abs", "change_pct", "ratio", "other"]
MetricCertainty = Literal["explicit", "implicit", "uncertain"]

# Risk-level enums.
RiskSentiment = Literal["negative", "neutral", "positive", "mixed", "unknown"]
RiskSeverity = Literal["low", "medium", "high", "unknown"]

# Segments use the same direction / certainty vocabularies as metrics.
SegmentDirection = MetricDirection
SegmentCertainty = MetricCertainty

# Statement-level sentiment.
OverallSentiment = Literal["positive", "negative", "neutral", "mixed", "unknown"]


class MetricItem(BaseModel):
    # One metric mention extracted from a statement. Only `name` is required;
    # every other field has a default. The Field descriptions are part of the
    # JSON schema built from this model, so they are kept verbatim.
    # NOTE(review): no class docstring on purpose - pydantic would copy it into
    # the generated schema description.

    # -- what was measured --
    name: str = Field(
        description="Short metric name, e.g. 'revenue', 'EPS', 'operating margin', 'cloud ARR'.",
    )
    category: Optional[str] = Field(
        None,
        description="Category of metric, e.g. 'financial', 'operational', 'user', 'margin', 'cashflow', 'other'.",
    )

    # -- the reported number and how to read it --
    value: Optional[float] = Field(
        None,
        description="Numeric value if explicitly mentioned; null if no clear numeric value.",
    )
    value_type: MetricValueType = Field(
        "level",
        description="Type of value: 'level' (absolute value), 'change_abs', 'change_pct', 'ratio', or 'other'.",
    )
    unit: Optional[str] = Field(
        None,
        description="Unit of the value, e.g. 'million', 'billion', 'percent', 'bps', 'users', or null.",
    )
    currency: Optional[str] = Field(
        None,
        description="Currency code if relevant, e.g. 'USD', 'EUR', or null.",
    )

    # -- trend, timing and confidence --
    direction: MetricDirection = Field(
        "unknown",
        description="Direction of change compared to prior period: 'up', 'down', 'flat', 'mixed', or 'unknown'.",
    )
    is_guidance: Optional[bool] = Field(
        None,
        description="true if clearly forward-looking guidance, false if clearly realized/historical, null if unclear.",
    )
    period: Optional[str] = Field(
        None,
        description="Time period for this metric, e.g. 'Q3 2024', 'FY24', 'next quarter', 'next fiscal year', or null.",
    )
    certainty: MetricCertainty = Field(
        "explicit",
        description="Was this metric explicitly stated ('explicit'), inferred from language ('implicit'), or uncertain?",
    )

    # -- supporting text --
    evidence_span: Optional[str] = Field(
        None,
        description="Short snippet or phrase from the statement that supports this metric mention.",
    )
    context: Optional[str] = Field(
        None,
        description="Short context for the metric, e.g. 'North America cloud segment', 'organic growth', etc.",
    )


class RiskItem(BaseModel):
    # One risk mention extracted from a statement. Only `type` is required.
    # Field descriptions are kept verbatim because they are part of the JSON
    # schema built from this model.
    # NOTE(review): `certainty` is typed with SegmentCertainty (same values as
    # the metric certainty vocabulary); there is no dedicated risk alias.

    # Short risk label (required).
    type: str = Field(
        description="Short risk label, e.g. 'FX', 'macro', 'regulation', 'competition', 'supply chain', 'execution'.",
    )
    sentiment: RiskSentiment = Field(
        "unknown",
        description="Sentiment associated with this risk: 'negative', 'neutral', 'positive', 'mixed', or 'unknown'.",
    )
    severity: RiskSeverity = Field(
        "unknown",
        description="Estimated severity of the risk if mentioned: 'low', 'medium', 'high', or 'unknown'.",
    )
    certainty: SegmentCertainty = Field(
        "explicit",
        description="Was this risk explicitly mentioned, implicitly implied, or uncertain?",
    )

    # Supporting text.
    evidence_span: Optional[str] = Field(
        None,
        description="Short snippet or phrase from the statement that describes this risk.",
    )
    context: Optional[str] = Field(
        None,
        description="Short context for the risk, e.g. 'Europe', 'consumer hardware', 'China regulation'.",
    )


class SegmentItem(BaseModel):
    # One business segment / product / region mention extracted from a
    # statement. Only `name` is required. Field descriptions are kept verbatim
    # because they are part of the JSON schema built from this model.

    # Segment, product or region name (required).
    name: str = Field(
        description="Name of the business segment, product, or region, e.g. 'cloud', 'PC', 'APAC', 'gaming'.",
    )
    direction: SegmentDirection = Field(
        "unknown",
        description="Direction of performance for this segment: 'up', 'down', 'flat', 'mixed', or 'unknown'.",
    )
    is_guidance: Optional[bool] = Field(
        None,
        description="true if the statement is about future performance for this segment, false if historical, null if unclear.",
    )
    certainty: SegmentCertainty = Field(
        "explicit",
        description="Was this segment performance explicitly described, implicitly implied, or uncertain?",
    )

    # Supporting text.
    evidence_span: Optional[str] = Field(
        None,
        description="Short snippet or phrase from the statement that mentions this segment.",
    )
    context: Optional[str] = Field(
        None,
        description="Short context or qualifier, e.g. 'enterprise customers', 'US market', 'SMB'.",
    )


class ExtractionResponse(BaseModel):
    # Top-level payload expected back from the model for a single statement:
    # three (possibly empty) lists of extracted items plus one sentiment label.

    # Each list defaults to a fresh empty list when the key is absent.
    metrics: List[MetricItem] = Field(default_factory=list)
    risks: List[RiskItem] = Field(default_factory=list)
    segments: List[SegmentItem] = Field(default_factory=list)

    overall_sentiment: OverallSentiment = Field(
        "unknown",
        description="Overall sentiment of the statement about the company's performance.",
    )


In [6]:
# JSON schema generated from ExtractionResponse; passed to Gemini as
# response_schema in call_gemini.
RESPONSE_SCHEMA = ExtractionResponse.model_json_schema()

In [7]:
# System instruction sent with every extraction request (see call_gemini).
SYSTEM_PROMPT = """
You are an expert financial information extraction model.
You read earnings call statements and extract metrics, guidance, risks, and business segments.
Always return valid JSON following the exact response_schema.
"""

In [8]:
def _server_error_code(err):
    """Best-effort extraction of the HTTP status code carried by a ServerError.

    Prefers the exception's integer ``code`` attribute when present and falls
    back to inspecting ``err.args[0]`` (dict payload or message string) so
    older error shapes keep working.
    """
    code = getattr(err, "code", None)
    if isinstance(code, int):
        return code

    if err.args:
        first_arg = err.args[0]
        if isinstance(first_arg, dict) and "error" in first_arg:
            return first_arg["error"].get("code")
        if isinstance(first_arg, str) and "503" in first_arg:
            return 503
    return None


def call_gemini(prompt: str, max_retries: int = 3, retry_delay: float = 1.0):
    """Send ``prompt`` to Gemini and return the raw JSON text of the response.

    The request uses the module-level ``client``, ``MODEL``, ``TEMPERATURE``,
    ``SYSTEM_PROMPT`` and ``RESPONSE_SCHEMA``.

    Args:
        prompt: Fully rendered user prompt.
        max_retries: Maximum number of retries after a 503 (overloaded) error.
        retry_delay: Base delay in seconds; the wait doubles on every retry
            (retry_delay, 2*retry_delay, 4*retry_delay, ...) so an overloaded
            service is not hit again at a fixed rate.

    Returns:
        ``response.text`` from the model.

    Raises:
        ServerError: For any non-503 server error, or once 503 retries are
            exhausted.
        RuntimeError: If the retry loop exits without returning or raising
            (only reachable with a negative ``max_retries``).
    """
    attempts = 0

    while attempts <= max_retries:
        try:
            response = client.models.generate_content(
                model=MODEL,
                contents=prompt,
                config=types.GenerateContentConfig(
                    temperature=TEMPERATURE,
                    system_instruction=SYSTEM_PROMPT,
                    response_mime_type="application/json",
                    response_schema=RESPONSE_SCHEMA,
                ),
            )
            return response.text

        except ServerError as e:
            # Only "model overloaded" is considered transient; everything else
            # is re-raised untouched.
            if _server_error_code(e) != 503:
                raise

            attempts += 1
            if attempts > max_retries:
                print("Max retries exceeded. Raising error.")
                raise

            delay = retry_delay * 2 ** (attempts - 1)
            print(f"Model overloaded (503). Retrying {attempts}/{max_retries} in {delay:.1f}s...")
            time.sleep(delay)

    # shouldn't really get here
    raise RuntimeError("Unexpected failure in call_gemini retry loop.")


In [9]:
def build_prompt(row: pd.Series) -> str:
    """Render the extraction prompt for a single statement row.

    Reads ``company_name``, ``companyid``, ``call_period``,
    ``transcriptpersonname``, ``speakertypename``,
    ``transcriptcomponenttypename`` and ``clean_text`` from ``row`` and
    returns the instruction text with those values filled in.
    """
    # Pull the fields out first (same order as they appear in the template).
    company = row['company_name']
    company_id = row['companyid']
    period = row['call_period']
    speaker = row['transcriptpersonname']
    speaker_role = row['speakertypename']
    component_type = row['transcriptcomponenttypename']
    statement_text = row['clean_text']

    return f"""
You will extract structured financial knowledge from the following earnings call statement.

Rules:
- Only extract metrics, risks, and segments that are clearly supported by the text.
- If no numeric value is explicitly mentioned, set 'value' to null but still fill 'name', 'direction', etc. if meaningful.
- Use the enums exactly as defined in the schema. Do NOT invent new values.
- For each item, include a short 'evidence_span' copied from the statement.
- If you are not sure, prefer 'unknown' for enums and null for free-text fields.

Statement metadata:
Company: {company}
Company ID: {company_id}
Call Period: {period}
Speaker: {speaker} ({speaker_role})
Segment Type: {component_type}

Statement text:
\"\"\"{statement_text}\"\"\"
"""

In [10]:
# Independent copy of the rows for company 24937.0 (Apple, per the preview
# below) in call_year 2025, call_quarter "Q4".
aapl_q4_2025 = call_section.loc[
    call_section["companyid"].eq(24937.0)
    & call_section["call_year"].eq(2025)
    & call_section["call_quarter"].eq("Q4")
].copy()

In [11]:
# Display the filtered subset as a visual check of the selection.
aapl_q4_2025

Unnamed: 0,companyid,company_name,market_cap,transcriptid,headline,mostimportantdateutc,mostimportanttimeutc,keydeveventtypeid,keydeveventtypename,transcriptcollectiontypeid,...,speakertypename,componentorder,componenttext,segment_id,segment_idx,call_year,call_quarter,call_period,clean_text,statement_id
2592,24937.0,APPLE INC,3761715.1938,3573894.0,"Apple Inc., Q4 2025 Earnings Call, Oct 30, 2025",2025-10-30,21:00:00,48.0,Earnings Calls,7,...,Executives,0,"Good afternoon, and welcome to the Apple Q4 Fi...",3573894.0_0,2592,2025,Q4,2025 Q4,"Good afternoon, and welcome to the Apple Q4 Fi...",3573894.0_0
2593,24937.0,APPLE INC,3761715.1938,3573894.0,"Apple Inc., Q4 2025 Earnings Call, Oct 30, 2025",2025-10-30,21:00:00,48.0,Earnings Calls,7,...,Executives,1,"Thank you, Suhasini. Good afternoon, everyone,...",3573894.0_1,2593,2025,Q4,2025 Q4,"Thank you, Suhasini. Good afternoon, everyone,...",3573894.0_1
2594,24937.0,APPLE INC,3761715.1938,3573894.0,"Apple Inc., Q4 2025 Earnings Call, Oct 30, 2025",2025-10-30,21:00:00,48.0,Earnings Calls,7,...,Executives,2,"Thanks, Tim, and good afternoon, everyone. Our...",3573894.0_2,2594,2025,Q4,2025 Q4,"Thanks, Tim, and good afternoon, everyone. Our...",3573894.0_2
2595,24937.0,APPLE INC,3761715.1938,3573894.0,"Apple Inc., Q4 2025 Earnings Call, Oct 30, 2025",2025-10-30,21:00:00,48.0,Earnings Calls,7,...,Executives,3,"[Operator Instructions] Operator, may we have...",3573894.0_3,2595,2025,Q4,2025 Q4,"[Operator Instructions] Operator, may we have ...",3573894.0_3
2596,24937.0,APPLE INC,3761715.1938,3573894.0,"Apple Inc., Q4 2025 Earnings Call, Oct 30, 2025",2025-10-30,21:00:00,48.0,Earnings Calls,7,...,Operator,4,"Certainly, we will go ahead and take our first...",3573894.0_4,2596,2025,Q4,2025 Q4,"Certainly, we will go ahead and take our first...",3573894.0_4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2800,24937.0,APPLE INC,3761715.1938,3580091.0,"Apple Inc., Q4 2025 Earnings Call, Oct 30, 2025",2025-10-30,21:00:00,48.0,Earnings Calls,1,...,Executives,64,I think that there are many factors that influ...,3580091.0_64,2800,2025,Q4,2025 Q4,I think that there are many factors that influ...,3580091.0_64
2801,24937.0,APPLE INC,3761715.1938,3580091.0,"Apple Inc., Q4 2025 Earnings Call, Oct 30, 2025",2025-10-30,21:00:00,48.0,Earnings Calls,1,...,Analysts,65,Okay. And then one for Kevan. In the wake of n...,3580091.0_65,2801,2025,Q4,2025 Q4,Okay. And then one for Kevan. In the wake of n...,3580091.0_65
2802,24937.0,APPLE INC,3761715.1938,3580091.0,"Apple Inc., Q4 2025 Earnings Call, Oct 30, 2025",2025-10-30,21:00:00,48.0,Earnings Calls,1,...,Executives,66,"Richard, thanks for the question. In general, ...",3580091.0_66,2802,2025,Q4,2025 Q4,"Richard, thanks for the question. In general, ...",3580091.0_66
2803,24937.0,APPLE INC,3761715.1938,3580091.0,"Apple Inc., Q4 2025 Earnings Call, Oct 30, 2025",2025-10-30,21:00:00,48.0,Earnings Calls,1,...,Executives,67,A replay of today's call will be available for...,3580091.0_67,2803,2025,Q4,2025 Q4,A replay of today's call will be available for...,3580091.0_67


In [None]:
# TODO: convert to using agentics for this, can scale since parallel etc...

# Accumulators for the four output tables: one statement row per processed
# statement, plus one fact row per extracted metric / risk / segment.
metric_rows = []
risk_rows = []
segment_rows = []
statement_rows = []

# Statements that could not be processed, with the reason, so they can be
# inspected or retried instead of being silently dropped.
failed_statements = []

# for _, row in tqdm(call_section.iterrows(), total=len(call_section)):
# for _, row in tqdm(call_section.loc[[10]].iterrows(), total=1):
for _, row in tqdm(aapl_q4_2025.iterrows(), total=len(aapl_q4_2025)):
    statement_id = row["statement_id"]
    company_id = row["companyid"]
    call_id = row["transcriptid"]
    try:
        prompt = build_prompt(row)
        response_text = call_gemini(prompt)

        # Validate the model output against the pydantic schema.
        data = ExtractionResponse.model_validate_json(response_text)

        statement_rows.append({
            "statement_id": statement_id,
            "segment_id": row["segment_id"],
            "segment_idx": row["segment_idx"],
            "company_id": company_id,
            "company_name": row["company_name"],
            "call_id": call_id,
            "call_date": row["mostimportantdateutc"],
            "call_period": row["call_period"],
            "call_year": row["call_year"],
            "call_quarter": row["call_quarter"],
            "speaker_name": row["transcriptpersonname"],
            "speaker_role": row["speakertypename"],
            "segment_type": row["transcriptcomponenttypename"],
            "text": row["clean_text"],
            "overall_sentiment": data.overall_sentiment,
        })

        for m in data.metrics:
            metric_rows.append({
                "fact_id": str(uuid.uuid4()),
                "statement_id": statement_id,
                "company_id": company_id,
                "call_id": call_id,
                "call_period": row["call_period"],
                "metric_name": m.name,
                "metric_category": m.category,
                "metric_value": m.value,
                "metric_value_type": m.value_type,      # "level", "change_abs", "change_pct", "ratio", "other"
                "metric_unit": m.unit,
                "metric_currency": m.currency,
                "metric_direction": m.direction,        # "up", "down", "flat", "mixed", "unknown"
                "metric_is_guidance": m.is_guidance,
                "metric_period": m.period,
                "metric_certainty": m.certainty,        # "explicit", "implicit", "uncertain"
                "metric_evidence_span": m.evidence_span,
                "metric_context": m.context,
            })

        for r in data.risks:
            risk_rows.append({
                "fact_id": str(uuid.uuid4()),
                "statement_id": statement_id,
                "company_id": company_id,
                "call_id": call_id,
                "call_period": row["call_period"],
                "risk_type": r.type,
                "risk_sentiment": r.sentiment,          # "negative", "neutral", "positive", "mixed", "unknown"
                "risk_severity": r.severity,            # "low", "medium", "high", "unknown"
                "risk_certainty": r.certainty,          # "explicit", "implicit", "uncertain"
                "risk_evidence_span": r.evidence_span,
                "risk_context": r.context,
            })

        for s in data.segments:
            segment_rows.append({
                "fact_id": str(uuid.uuid4()),
                "statement_id": statement_id,
                "company_id": company_id,
                "call_id": call_id,
                "call_period": row["call_period"],
                "segment_name": s.name,
                "segment_direction": s.direction,       # "up", "down", "flat", "mixed", "unknown"
                "segment_is_guidance": s.is_guidance,
                "segment_certainty": s.certainty,       # "explicit", "implicit", "uncertain"
                "segment_evidence_span": s.evidence_span,
                "segment_context": s.context,
            })

        # Fixed pause between successful requests.
        time.sleep(2)
    except Exception as exc:
        # Best-effort per statement: record and report the cause, then move on.
        # `Exception` (not a bare except) so Ctrl-C / kernel interrupt still
        # stops this long-running loop.
        failed_statements.append({
            "statement_id": statement_id,
            "call_id": call_id,
            "company_id": company_id,
            "error": f"{type(exc).__name__}: {exc}",
        })
        print(
            f"[ERROR] Statement {statement_id} on Call {call_id} for Company {company_id}: "
            f"{type(exc).__name__}: {exc}"
        )
        continue


100%|██████████| 213/213 [32:15<00:00,  9.09s/it] 


In [13]:
# Materialise the accumulated record lists as one DataFrame per table.
statements_df    = pd.DataFrame(statement_rows)
metric_facts_df  = pd.DataFrame(metric_rows)
risk_facts_df    = pd.DataFrame(risk_rows)
segment_facts_df = pd.DataFrame(segment_rows)

In [14]:
# Persist the four tables so the (slow) extraction does not have to be re-run.
# NOTE(review): assumes data/processed/ already exists — confirm.
statements_df.to_parquet('data/processed/statements.parquet')
metric_facts_df.to_parquet('data/processed/metrics.parquet')
risk_facts_df.to_parquet('data/processed/risks.parquet')
segment_facts_df.to_parquet('data/processed/segments.parquet')

In [19]:
from neo4j import GraphDatabase

# Neo4j connection settings are read from the environment so that no
# credential is stored in the notebook.
#   NEO4J_URI       optional, defaults to the instance used by this notebook
#   NEO4J_USERNAME  optional, defaults to "neo4j"
#   NEO4J_PASSWORD  required; a missing value raises KeyError here rather than
#                   failing later with an authentication error
# NOTE(security): a password was previously committed in this cell; it should
# be treated as leaked and rotated.
URI  = os.environ.get("NEO4J_URI", "neo4j+s://044150d7.databases.neo4j.io")
AUTH = (os.environ.get("NEO4J_USERNAME", "neo4j"), os.environ["NEO4J_PASSWORD"])

def create_constraints():
    """Create the six uniqueness constraints used by the graph, if missing.

    Uses the module-level ``driver``; each statement is sent in its own
    ``session.run`` call, in the order listed.
    """
    constraint_statements = (
        """
            CREATE CONSTRAINT company_id_unique IF NOT EXISTS
            FOR (c:Company) REQUIRE c.company_id IS UNIQUE
        """,
        """
            CREATE CONSTRAINT call_id_unique IF NOT EXISTS
            FOR (call:Call) REQUIRE call.call_id IS UNIQUE
        """,
        """
            CREATE CONSTRAINT statement_id_unique IF NOT EXISTS
            FOR (s:Statement) REQUIRE s.statement_id IS UNIQUE
        """,
        """
            CREATE CONSTRAINT metric_name_unique IF NOT EXISTS
            FOR (m:Metric) REQUIRE m.name IS UNIQUE
        """,
        """
            CREATE CONSTRAINT risk_type_unique IF NOT EXISTS
            FOR (r:Risk) REQUIRE r.type IS UNIQUE
        """,
        """
            CREATE CONSTRAINT segment_name_unique IF NOT EXISTS
            FOR (seg:Segment) REQUIRE seg.name IS UNIQUE
        """,
    )

    with driver.session() as session:
        for cypher in constraint_statements:
            session.run(cypher)

# Open a driver for the duration of this cell. The name `driver` bound here is
# the module-level variable that create_constraints() reads.
with GraphDatabase.driver(URI, auth=AUTH) as driver:
    driver.verify_connectivity()
    print("Connection established.")
    create_constraints()
    print("Constraints added.")

Connection established.
Constraints added.


In [40]:
# Show the available statement columns before wiring them into the ingest query.
statements_df.columns

Index(['statement_id', 'segment_id', 'segment_idx', 'company_id',
       'company_name', 'call_id', 'call_period', 'call_year', 'call_quarter',
       'speaker_name', 'speaker_role', 'segment_type', 'text',
       'overall_sentiment'],
      dtype='object')

In [41]:
from tqdm import tqdm
def ingest_statements(statements_df):
    """Upsert Company, Call and Statement nodes (and their links) row by row.

    Uses the module-level ``driver``, which must be open when this is called.

    Args:
        statements_df: One row per statement with at least ``company_id``,
            ``company_name``, ``call_id``, ``call_period``, ``statement_id``,
            ``overall_sentiment`` and the statement body under ``text``
            (``clean_text`` is accepted as a fallback). ``call_date``,
            ``call_year`` and ``call_quarter`` are optional.
    """
    with driver.session() as session:
        for _, row in tqdm(statements_df.iterrows(), total=len(statements_df)):
            # The extraction step stores the statement body under "text";
            # frames that still carry the raw "clean_text" column also work.
            text = row["text"] if "text" in row else row["clean_text"]

            session.run(
                """
                MERGE (c:Company {company_id: $company_id})
                  ON CREATE SET c.name = $company_name

                MERGE (call:Call {call_id: $call_id})
                  ON CREATE SET
                    call.period        = $call_period,
                    call.call_date     = $call_date,
                    call.fiscal_year   = $fiscal_year,
                    call.fiscal_quarter= $fiscal_quarter

                MERGE (s:Statement {statement_id: $statement_id})
                  ON CREATE SET 
                    s.text             = $text,
                    s.overall_sentiment= $overall_sentiment,
                    s.call_period      = $call_period

                MERGE (c)-[:HAS_CALL]->(call)
                MERGE (call)-[:HAS_STATEMENT]->(s)
                """,
                {
                    "company_id":        row["company_id"],
                    "company_name":      row["company_name"],
                    "call_id":           row["call_id"],
                    "call_period":       row["call_period"],
                    "call_date":         row.get("call_date"),
                    # Parameter names must match the $placeholders in the query.
                    "fiscal_year":       row.get("call_year"),
                    "fiscal_quarter":    row.get("call_quarter"),
                    "statement_id":      row["statement_id"],
                    "text":              text,
                    "overall_sentiment": row["overall_sentiment"],
                },
            )

In [None]:
# `driver` bound here is the module-level variable that ingest_statements() reads.
with GraphDatabase.driver(URI, auth=AUTH) as driver:
    ingest_statements(statements_df)
    print("Statements ingested.")

100%|██████████| 213/213 [00:12<00:00, 16.58it/s]


Statements ingested.


In [21]:
def ingest_metrics(metric_facts_df):
    """Attach metric facts to their Statement nodes.

    Each row becomes one ``MENTIONS_METRIC`` relationship keyed by ``fact_id``,
    so a statement that mentions the same metric several times (e.g. a result
    and guidance for "revenue") keeps every fact. Re-running with the same
    frame stays idempotent because the MERGE matches on ``fact_id``.

    NOTE(review): relationships written by an earlier version without
    ``fact_id`` are not matched by this MERGE and would remain alongside the
    new ones - clear them before re-ingesting into an existing graph.

    Uses the module-level ``driver``, which must be open when this is called.

    Args:
        metric_facts_df: One row per metric fact with ``fact_id``,
            ``statement_id`` and ``metric_name``; the remaining ``metric_*``
            columns are optional.
    """
    def _clean(value):
        # Missing numbers arrive from pandas as NaN; send them as null so the
        # graph (and JSON built from it) does not contain NaN.
        return None if pd.isna(value) else value

    optional_keys = (
        "metric_value",
        "metric_value_type",
        "metric_unit",
        "metric_currency",
        "metric_direction",
        "metric_is_guidance",
        "metric_period",
        "metric_certainty",
        "metric_evidence_span",
        "metric_context",
    )

    with driver.session() as session:
        for _, row in tqdm(metric_facts_df.iterrows(), total=len(metric_facts_df)):
            params = {key: _clean(row.get(key)) for key in optional_keys}
            params["statement_id"] = row["statement_id"]
            params["metric_name"] = row["metric_name"]
            params["fact_id"] = row["fact_id"]

            session.run(
                """
                MATCH (s:Statement {statement_id: $statement_id})
                MERGE (m:Metric {name: $metric_name})
                MERGE (s)-[rel:MENTIONS_METRIC {fact_id: $fact_id}]->(m)
                  ON CREATE SET
                    rel.value          = $metric_value,
                    rel.value_type     = $metric_value_type,
                    rel.unit           = $metric_unit,
                    rel.currency       = $metric_currency,
                    rel.direction      = $metric_direction,
                    rel.is_guidance    = $metric_is_guidance,
                    rel.metric_period  = $metric_period,
                    rel.certainty      = $metric_certainty,
                    rel.evidence_span  = $metric_evidence_span,
                    rel.context        = $metric_context
                """,
                params,
            )


# `driver` bound here is the module-level variable that ingest_metrics() reads.
with GraphDatabase.driver(URI, auth=AUTH) as driver:
    ingest_metrics(metric_facts_df)
    print("Metrics ingested")

100%|██████████| 482/482 [00:33<00:00, 14.22it/s]

Metrics ingested





In [22]:
def ingest_risks(risk_facts_df):
    """Attach risk facts to their Statement nodes.

    Each row becomes one ``MENTIONS_RISK`` relationship keyed by ``fact_id``,
    so several mentions of the same risk type within one statement are all
    kept instead of collapsing into the first one. Re-running with the same
    frame stays idempotent because the MERGE matches on ``fact_id``.

    NOTE(review): relationships written by an earlier version without
    ``fact_id`` are not matched by this MERGE - clear them before re-ingesting
    into an existing graph.

    Uses the module-level ``driver``, which must be open when this is called.

    Args:
        risk_facts_df: One row per risk fact with ``fact_id``,
            ``statement_id`` and ``risk_type``; the remaining ``risk_*``
            columns are optional.
    """
    def _clean(value):
        # Map pandas' missing-value markers (NaN/NaT/None) to null.
        return None if pd.isna(value) else value

    optional_keys = (
        "risk_sentiment",
        "risk_severity",
        "risk_certainty",
        "risk_evidence_span",
        "risk_context",
    )

    with driver.session() as session:
        for _, row in tqdm(risk_facts_df.iterrows(), total=len(risk_facts_df)):
            params = {key: _clean(row.get(key)) for key in optional_keys}
            params["statement_id"] = row["statement_id"]
            params["risk_type"] = row["risk_type"]
            params["fact_id"] = row["fact_id"]

            session.run(
                """
                MATCH (s:Statement {statement_id: $statement_id})
                MERGE (r:Risk {type: $risk_type})
                MERGE (s)-[rel:MENTIONS_RISK {fact_id: $fact_id}]->(r)
                  ON CREATE SET
                    rel.sentiment     = $risk_sentiment,
                    rel.severity      = $risk_severity,
                    rel.certainty     = $risk_certainty,
                    rel.evidence_span = $risk_evidence_span,
                    rel.context       = $risk_context
                """,
                params,
            )


# `driver` bound here is the module-level variable that ingest_risks() reads.
with GraphDatabase.driver(URI, auth=AUTH) as driver:
    ingest_risks(risk_facts_df)
    print("Risks ingested")

100%|██████████| 112/112 [00:07<00:00, 14.26it/s]

Risks ingested





In [23]:
def ingest_segments(segment_facts_df):
    """Attach segment facts to their Statement nodes.

    Each row becomes one ``MENTIONS_SEGMENT`` relationship keyed by
    ``fact_id``, so several mentions of the same segment within one statement
    are all kept instead of collapsing into the first one. Re-running with the
    same frame stays idempotent because the MERGE matches on ``fact_id``.

    NOTE(review): relationships written by an earlier version without
    ``fact_id`` are not matched by this MERGE - clear them before re-ingesting
    into an existing graph.

    Uses the module-level ``driver``, which must be open when this is called.

    Args:
        segment_facts_df: One row per segment fact with ``fact_id``,
            ``statement_id`` and ``segment_name``; the remaining ``segment_*``
            columns are optional.
    """
    def _clean(value):
        # Map pandas' missing-value markers (NaN/NaT/None) to null.
        return None if pd.isna(value) else value

    optional_keys = (
        "segment_direction",
        "segment_is_guidance",
        "segment_certainty",
        "segment_evidence_span",
        "segment_context",
    )

    with driver.session() as session:
        for _, row in tqdm(segment_facts_df.iterrows(), total=len(segment_facts_df)):
            params = {key: _clean(row.get(key)) for key in optional_keys}
            params["statement_id"] = row["statement_id"]
            params["segment_name"] = row["segment_name"]
            params["fact_id"] = row["fact_id"]

            session.run(
                """
                MATCH (s:Statement {statement_id: $statement_id})
                MERGE (seg:Segment {name: $segment_name})
                MERGE (s)-[rel:MENTIONS_SEGMENT {fact_id: $fact_id}]->(seg)
                  ON CREATE SET
                    rel.direction      = $segment_direction,
                    rel.is_guidance    = $segment_is_guidance,
                    rel.certainty      = $segment_certainty,
                    rel.evidence_span  = $segment_evidence_span,
                    rel.context        = $segment_context
                """,
                params,
            )


# `driver` bound here is the module-level variable that ingest_segments() reads.
with GraphDatabase.driver(URI, auth=AUTH) as driver:
    ingest_segments(segment_facts_df)
    print("Segments ingested.")

100%|██████████| 434/434 [00:24<00:00, 17.63it/s]

Segments ingested.





In [None]:
import numpy as np

# Embedding model name passed to client.models.embed_content in embed_text.
EMBEDDING_MODEL = "text-embedding-004"
def embed_text(text: str) -> np.ndarray:
    """Embed one string and return its L2-normalised float32 vector.

    Sends a single-item request through the module-level ``client`` using
    ``EMBEDDING_MODEL``.
    """
    response = client.models.embed_content(model=EMBEDDING_MODEL, contents=[text])
    raw = np.array(response.embeddings[0].values, dtype="float32")
    # Guard against division by zero for an all-zero vector.
    norm = max(np.linalg.norm(raw), 1e-12)
    return raw / norm

def embed_batch(texts: List[str]) -> List[List[float]]:
    """Embed each text in order (one embed_text call per item) as plain lists."""
    vectors = []
    for item in texts:
        vectors.append(embed_text(item).tolist())
    return vectors

def embed_query(query: str) -> list[float]:
    """Embed a search query and return the normalised vector as a plain list."""
    vector = embed_text(query)
    return vector.tolist()

In [None]:
# Parallel lists: statement bodies, their ids, and one embedding per body
# (embed_batch preserves input order, so ids[i] belongs to embs[i]).
texts = statements_df["text"].tolist()
ids   = statements_df["statement_id"].tolist()
embs  = embed_batch(texts)

In [27]:
# Store each embedding on its Statement node, one query per statement.
# Statements that do not exist in the graph are skipped by the MATCH.
with GraphDatabase.driver(URI, auth=AUTH) as driver:
    with driver.session() as session:
        for sid, vec in tqdm(list(zip(ids, embs))):
            session.run(
                """
                MATCH (s:Statement {statement_id: $sid})
                SET s.embedding = $vec
                """,
                {"sid": sid, "vec": vec},
            )

100%|██████████| 213/213 [00:15<00:00, 13.48it/s]


In [30]:
# Vector length taken from the first embedding computed above.
EMBED_DIM = len(embs[0])  # e.g. 768

# Create the cosine-similarity vector index over Statement.embedding if it
# does not exist yet. The doubled braces are literal braces in the f-string;
# only EMBED_DIM is interpolated.
with GraphDatabase.driver(URI, auth=AUTH) as driver:
    with driver.session() as session: 
       session.run(f"""
                   CREATE VECTOR INDEX statement_embedding_index IF NOT EXISTS
                   FOR (s:Statement) ON (s.embedding)
                   OPTIONS {{
                   indexConfig: {{
                    `vector.dimensions`: {EMBED_DIM},
                    `vector.similarity_function`: 'cosine'
                   }}
                }}
              """)

In [35]:
import json

# Default number of nearest statements requested from the vector index.
TOP_K = 10
# Long-lived driver used by graphrag_query.
# NOTE(review): unlike the `with` blocks in earlier cells, nothing in the
# cells shown closes this driver.
driver = GraphDatabase.driver(URI, auth=AUTH)

def graphrag_query(
    question: str,
    company_id: float | None = None,
    call_period: str | None = None,
    top_k: int = TOP_K,
) -> dict:
    """
    GraphRAG-style retrieval:
    1. Embed question
    2. Vector search on Statement nodes
    3. Expand to metrics/risks/segments + company/call

    Uses the module-level ``driver`` and ``embed_query``.

    Args:
        question: Natural-language question to embed.
        company_id: If given, keep only statements of this company.
        call_period: If given, keep only statements with this call period.
        top_k: Number of nearest neighbours requested from the vector index.

    Returns:
        A dict with the echoed ``question`` / ``company_id`` / ``call_period``
        and ``statements``: one record per matched statement with its text,
        company/call info, collected metrics, risks and segments, and the
        vector-search ``score``. Statements are ordered by descending score
        (the expansion query on its own gives no ordering guarantee).

    NOTE(review): the company / call-period filters are applied after the
    index has returned its ``top_k`` neighbours, so a filtered query can
    yield fewer than ``top_k`` statements.
    """
    q_vec = embed_query(question)

    with driver.session() as session:
        # 1) Vector search for pivot statements
        vec_res = session.run(
            """
            CALL db.index.vector.queryNodes(
                'statement_embedding_index', $top_k, $query_vec
            )
            YIELD node, score
            WHERE ($company_id IS NULL OR EXISTS {
                     MATCH (c:Company)-[:HAS_CALL]->(:Call)-[:HAS_STATEMENT]->(node)
                     WHERE c.company_id = $company_id
                  })
              AND ($call_period IS NULL OR node.call_period = $call_period)
            RETURN node.statement_id AS statement_id, score
            """,
            {
                "top_k": top_k,
                "query_vec": q_vec,
                "company_id": company_id,
                "call_period": call_period,
            },
        ).data()

        # Remember each hit's similarity so it can be re-attached after expansion.
        score_by_id = {hit["statement_id"]: hit["score"] for hit in vec_res}
        candidate_ids = list(score_by_id)

        if not candidate_ids:
            return {
                "question": question,
                "company_id": company_id,
                "call_period": call_period,
                "statements": [],
            }

        # 2) Expand neighborhoods: metrics, risks, segments, company, call
        result2 = session.run(
            """
            MATCH (c:Company)-[:HAS_CALL]->(call:Call)-[:HAS_STATEMENT]->(s:Statement)
            WHERE s.statement_id IN $statement_ids
            OPTIONAL MATCH (s)-[mm:MENTIONS_METRIC]->(m:Metric)
            OPTIONAL MATCH (s)-[rr:MENTIONS_RISK]->(r:Risk)
            OPTIONAL MATCH (s)-[sg:MENTIONS_SEGMENT]->(seg:Segment)
            RETURN
              s.statement_id AS statement_id,
              s.text         AS text,
              s.overall_sentiment AS overall_sentiment,
              c.company_id   AS company_id,
              c.name         AS company_name,
              call.call_id   AS call_id,
              call.period    AS call_period,
              collect(DISTINCT mm{.*, metric: m.name})    AS metrics,
              collect(DISTINCT rr{.*, risk: r.type})      AS risks,
              collect(DISTINCT sg{.*, segment: seg.name}) AS segments
            """,
            {"statement_ids": candidate_ids},
        ).data()

    # Restore relevance: attach the score and rank best-first.
    for record in result2:
        record["score"] = score_by_id.get(record["statement_id"])
    result2.sort(
        key=lambda rec: rec["score"] if rec["score"] is not None else float("-inf"),
        reverse=True,
    )

    # 3) Build tool-friendly JSON
    return {
        "question": question,
        "company_id": company_id,
        "call_period": call_period,
        "statements": result2,
    }


In [38]:
# Example retrieval restricted to company_id 24937.0 and call period "2025 Q4".
context = graphrag_query(
    "What did Apple say about iPhone demand and FX headwinds?",
    company_id=24937.0,
    call_period="2025 Q4",
)

import json
# Pretty-print the retrieved context for inspection.
print(json.dumps(context, indent=2))

{
  "question": "What did Apple say about iPhone demand and FX headwinds?",
  "company_id": 24937.0,
  "call_period": "2025 Q4",
  "statements": [
    {
      "statement_id": "3573894.0_27",
      "text": "Got it. And then if I just go back to the China discussion for a minute, the performance in China, at least in September quarter was a bit muted. Could you just talk about what resulted in the weakness over there? And do you think it was a bit more of a pause given iPhone Air, for example, I don't think was available until a few weeks ago. So just somewhat what drove the weakness in September? And is the uptick of that expectation for December there just from the iPhone Air coming out? Or are there other factors as well?",
      "overall_sentiment": "neutral",
      "company_id": 24937.0,
      "company_name": "APPLE INC",
      "call_id": 3573894.0,
      "call_period": "2025 Q4",
      "metrics": [],
      "risks": [
        {
          "risk": "product availability",
          "se