In [1]:
import os
from dotenv import load_dotenv
import pandas as pd
import numpy as np
import warnings
warnings.filterwarnings('ignore')
pd.set_option('display.max_columns', None)
from IPython.display import Markdown, display, HTML

import uuid
from tqdm import tqdm
from langchain_community.document_loaders import DataFrameLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter

load_dotenv()

True

In [2]:
os.chdir(os.path.dirname(os.getcwd()))

In [3]:
df = pd.read_parquet('citation_data_with_context.parquet')
df = df.drop_duplicates(subset=['body'], keep='first')

In [4]:
# from src.parsing.graph_maker import GraphMaker, Document, Ontology

In [16]:
from typing import List, Dict
from pydantic import BaseModel, ConfigDict, Field

class OntologyLabel(BaseModel):
    """A label describing a valid ontology item for a given topic or theme."""
    
    category: str = Field(
        ...,
        description="The category label representing the entity type.",
    )
    description: str = Field(
        ...,
        description="A description or definition for the entity category.",
    )
    
    model_config = ConfigDict(
        json_schema_extra={
            "examples": [
                {   "category": "Person", 
                    "description": "A person referenced in the context."},
                {
                    "category": "Insurer",
                    "description": "An insurance company referenced in the context.",
                },
                {
                    "category": "Case Law",
                    "description": "A specific case law cited in the context.",
                },
                {
                    "category": "Insurance Coverage",
                    "description": "A specific line of insurance coverage mentioned in the context.",
                },
            ]
        }
    )

class Ontology(BaseModel):
    """An ontology for a given text and user specified theme or topic."""
    
    labels: List[OntologyLabel] = Field(
        ..., 
        default_factory=list,
    )
    model_config = ConfigDict(
        extra="allow",
        arbitrary_types_allowed=True,
    )
    
    @property
    def to_pandas(self):
        ontology_dict = {
            "category": [n.category for n in self.labels],
            "description": [n.description for n in self.labels],

        }
        return pd.DataFrame(ontology_dict)


In [17]:
OntologyLabel.model_json_schema()

{'description': 'A label describing a valid ontology item for a given topic or theme.',
 'examples': [{'category': 'Person',
   'description': 'A person referenced in the context.'},
  {'category': 'Insurer',
   'description': 'An insurance company referenced in the context.'},
  {'category': 'Case Law',
   'description': 'A specific case law cited in the context.'},
  {'category': 'Insurance Coverage',
   'description': 'A specific line of insurance coverage mentioned in the context.'}],
 'properties': {'category': {'description': 'The category label representing the entity type.',
   'title': 'Category',
   'type': 'string'},
  'description': {'description': 'A description or definition for the entity category.',
   'title': 'Description',
   'type': 'string'}},
 'required': ['category', 'description'],
 'title': 'OntologyLabel',
 'type': 'object'}

In [19]:
from typing import List
from openai import OpenAI
import instructor
from pydantic import BaseModel, Field

client = instructor.from_openai(OpenAI())


def user_message(theme: str, text: str) -> str:
    return f"While focusing on the theme **{theme}**, generate an ontology for the following input text: ```\n{text}\n```"

def system_message() -> str:
    return (
        "You are an expert at creating an ontology for a given theme or topic. "
        "Users will provide you with a **theme** and an input text delimited by ```. "
        "Extract all the entity types from the input text relevant to the **theme**. "
        "The goal is to create an ontology to use for downstream knowledge graph construction for the **theme**."
    )


def generate_ontology(theme: str, text: str, model: str) -> str:
    return client.chat.completions.create(
        model=model,
        max_retries=3,
        messages=[
            {
                "role": "system",
                "content": system_message(),
            },
            {
                "role": "user",
                "content": user_message(theme=theme, text=text),
            },
        ],
        response_model=Ontology,
    )


In [8]:
example_ontology = Ontology(
    labels=[
        {"Person": "Person name without any adjectives, Remember a person may be referenced by their name or using a pronoun"},
        {"Object": "Do not add the definite article 'the' in the object name"},
        {"Event": "Event event involving multiple people. Do not include qualifiers or verbs like gives, leaves, works etc."},
        "Place",
        "Document",
        "Organization",
        "Action",
        {"Miscellaneous": "Any important concept can not be categorized with any other given label"},
    ],
    relationships=[
        "Relation between any pair of Entities"
        ],
)

In [9]:
example_ontology.dump

{'labels': [{'Person': 'Person name without any adjectives, Remember a person may be referenced by their name or using a pronoun'},
  {'Object': "Do not add the definite article 'the' in the object name"},
  {'Event': 'Event event involving multiple people. Do not include qualifiers or verbs like gives, leaves, works etc.'},
  'Place',
  'Document',
  'Organization',
  'Action',
  {'Miscellaneous': 'Any important concept can not be categorized with any other given label'}],
 'relationships': ['Relation between any pair of Entities']}

In [6]:
from typing import List
from openai import OpenAI
import instructor
from pydantic import BaseModel, Field

client = instructor.from_openai(OpenAI())
model = "gpt-3.5-turbo"


class Node(BaseModel):
    label: str
    name: str


class Edge(BaseModel):
    node_1: Node
    node_2: Node
    relationship: str
    

class KnowledgeGraph(BaseModel):
    edges: List[Edge] = Field(..., default_factory=list)
    
    @property
    def to_pandas(self):
        kg_dict = {
            "node_1": [n.node_1.name for n in self.edges],
            "node_2": [n.node_2.name for n in self.edges],
            "edge": [n.relationship for n in self.edges],
            "node_1_type": [n.node_1.label for n in self.edges],
            "node_2_type": [n.node_2.label for n in self.edges],
        }
        return pd.DataFrame(kg_dict)


def user_message(text: str) -> str:
    return f"input text: ```\n{text}\n```"

def system_message(ontology: Ontology = example_ontology) -> str:
    return (
        "You are an expert at creating Knowledge Graphs. "
        "Consider the following ontology. \n"
        f"{ontology} \n"
        "The user will provide you with an input text delimited by ```. "
        "Extract all the entities and relationships from the user-provided text as per the given ontology. Do not use any previous knowledge about the context."
        "Remember there can be multiple direct (explicit) or implied relationships between the same pair of nodes. "
        "Be consistent with the given ontology. Use ONLY the labels and relationships mentioned in the ontology. "
        "Remember to follow the correct format, for example:\n"
        "[\n"
        "   {\n"
        '       node_1: Required, an entity object with attributes: {"label": "as per the ontology", "name": "Name of the entity"},\n'
        '       node_2: Required, an entity object with attributes: {"label": "as per the ontology", "name": "Name of the entity"},\n'
        "       relationship: Describe the relationship between node_1 and node_2 as per the context, in one or two sentences.\n"
        "   },\n"
        "]\n"
    )


def generate_graph(text: str, model: str) -> str:
    return client.chat.completions.create(
        model=model,
        max_retries=3,
        messages=[
            {
                "role": "system",
                "content": system_message(),
            },
            {
                "role": "user",
                "content": user_message(text=text),
            },
        ],
        response_model=KnowledgeGraph,
    )


In [67]:
df.head(1)

Unnamed: 0,id,citation,name,name_abbreviation,decision_date,court_id,court_name,court_slug,judges,attorneys,citations,url,head,body,name_contains_lm,body_contains_lm,year,context,context_citation,context_tokens
0,411690,154 Ill. 2d 90,"RICHARD R. JOHNSON, Plaintiff-Appellant and Cr...",Johnson v. Halloran,2000-01-13,8837,Illinois Appellate Court,ill-app-ct,[],"['Wolter, Beeman, Lynch & McIntyre, of Springf...","[{'type': 'official', 'cite': '312 Ill. App. 3...",https://api.case.law/v1/cases/411690/,"RICHARD R. JOHNSON, Plaintiff-Appellant and Cr...",JUSTICE HALL\r\ndelivered the opinion of the c...,False,True,2000,The public defender of Cook County was appoint...,154 Ill. 2d 90,1317


In [20]:
sample_test = df.sample(1)
text_column = 'body'

splitter = RecursiveCharacterTextSplitter(
    chunk_size=5000,
    chunk_overlap=0,
    length_function=len,
    is_separator_regex=False,
)

def dataframe2Documents(df: pd.DataFrame, text_column: str):
    loader = DataFrameLoader(df, page_content_column=text_column)
    return loader.load()

documents = dataframe2Documents(sample_test, text_column)
docs = splitter.split_documents(documents)
len(docs)

4

In [21]:
def documents2Dataframe(documents) -> pd.DataFrame:
    rows = []
    for chunk in documents:
        row = {
            "text": chunk.page_content,
            **chunk.metadata,
            "chunk_id": uuid.uuid4().hex,
        }
        rows = rows + [row]
    df = pd.DataFrame(rows)
    return df

In [22]:
df = documents2Dataframe(docs)
print(df.shape)
df.head(2)

(4, 21)


Unnamed: 0,text,id,citation,name,name_abbreviation,decision_date,court_id,court_name,court_slug,judges,attorneys,citations,url,head,name_contains_lm,body_contains_lm,year,context,context_citation,context_tokens,chunk_id
0,PRESIDING JUSTICE HOFFMAN\r\ndelivered the opi...,3739664,197 Ill. 2d 28,"HOME INSURANCE COMPANY, as the Successor in In...",Home Insurance v. Cincinnati Insurance,2003-12-03,8837,Illinois Appellate Court,ill-app-ct,"['KARNEZIS, J., concurs.']","['Pretzel & Stouffer, of Chicago (Robert Marc ...","[{'type': 'official', 'cite': '345 Ill. App. 3...",https://api.case.law/v1/cases/3739664/,"HOME INSURANCE COMPANY, as the Successor in In...",False,True,2003,After the trial in the Fisher action commenced...,197 Ill. 2d 28,1361,2bf1200a140f4ccbaf61adf1c398fa26
1,Cincinnati and Home filed cross-motions for su...,3739664,197 Ill. 2d 28,"HOME INSURANCE COMPANY, as the Successor in In...",Home Insurance v. Cincinnati Insurance,2003-12-03,8837,Illinois Appellate Court,ill-app-ct,"['KARNEZIS, J., concurs.']","['Pretzel & Stouffer, of Chicago (Robert Marc ...","[{'type': 'official', 'cite': '345 Ill. App. 3...",https://api.case.law/v1/cases/3739664/,"HOME INSURANCE COMPANY, as the Successor in In...",False,True,2003,After the trial in the Fisher action commenced...,197 Ill. 2d 28,1361,220be81ed416456bbf0e24f40002f1dd


In [23]:
def df2Ontology(theme: str, df:pd.DataFrame, model: str ="gpt-4o") -> pd.DataFrame:
    progress_bar = tqdm(total=len(df), desc="Processing chunks")

    def apply_ontologyPrompt(row):
        result = generate_ontology(theme=theme, text=row.text, model=model)
        result_df = result.to_pandas
        result_df["chunk_id"] = row.chunk_id
        progress_bar.update(1)
        return result_df
    
    results = df.apply(apply_ontologyPrompt, axis=1)
    results_df = pd.concat(results.tolist(), ignore_index=True)
    progress_bar.close()
    
    return results_df

In [25]:
test = df2Ontology(
    theme="Insurance Coverage",
    df=df)

Processing chunks:   0%|          | 0/4 [00:00<?, ?it/s]

Processing chunks: 100%|██████████| 4/4 [00:13<00:00,  3.40s/it]


In [27]:
Markdown(test.to_markdown())

|    | category               | description                                                                                                                    | chunk_id                         |
|---:|:-----------------------|:-------------------------------------------------------------------------------------------------------------------------------|:---------------------------------|
|  0 | Insurance Company      | An entity that provides insurance coverage.                                                                                    | 2bf1200a140f4ccbaf61adf1c398fa26 |
|  1 | Insurance Policy       | A specific line of insurance coverage mentioned in the context.                                                                | 2bf1200a140f4ccbaf61adf1c398fa26 |
|  2 | Judicial Entity        | A court or a justice involved in the legal proceedings.                                                                        | 2bf1200a140f4ccbaf61adf1c398fa26 |
|  3 | Contractor             | A company or entity employed to perform construction or renovation work.                                                       | 2bf1200a140f4ccbaf61adf1c398fa26 |
|  4 | Subcontractor          | A company or entity employed by the contractor to carry out specific parts of the work.                                        | 2bf1200a140f4ccbaf61adf1c398fa26 |
|  5 | Litigation Party       | A party involved in litigation or legal proceedings.                                                                           | 2bf1200a140f4ccbaf61adf1c398fa26 |
|  6 | Legal Action           | A legal proceeding or action taken by one party against another.                                                               | 2bf1200a140f4ccbaf61adf1c398fa26 |
|  7 | Settlement             | An agreement reached between parties to resolve a legal dispute.                                                               | 2bf1200a140f4ccbaf61adf1c398fa26 |
|  8 | Legal Concept          | A legal idea or principle relevant to the case.                                                                                | 2bf1200a140f4ccbaf61adf1c398fa26 |
|  9 | Accident/Injury        | An incident where harm or injury occurred, triggering the insurance claims.                                                    | 2bf1200a140f4ccbaf61adf1c398fa26 |
| 10 | Person                 | Individuals mentioned in the context, including legal professionals and corporate representatives.                             | 220be81ed416456bbf0e24f40002f1dd |
| 11 | Insurer                | Insurance companies referenced in the context.                                                                                 | 220be81ed416456bbf0e24f40002f1dd |
| 12 | Insurance Coverage     | Types of insurance coverage mentioned in the context, including primary and excess coverage.                                   | 220be81ed416456bbf0e24f40002f1dd |
| 13 | Legal Action           | Specific legal actions or motions noted in the context.                                                                        | 220be81ed416456bbf0e24f40002f1dd |
| 14 | Legal Principle        | Legal doctrines or principles referenced in the context.                                                                       | 220be81ed416456bbf0e24f40002f1dd |
| 15 | Case Law               | Specific legal cases cited as references.                                                                                      | 220be81ed416456bbf0e24f40002f1dd |
| 16 | Statute                | Laws or statutes referenced in the context.                                                                                    | 220be81ed416456bbf0e24f40002f1dd |
| 17 | Organization           | Entities or firms mentioned that are not necessarily insurers.                                                                 | 220be81ed416456bbf0e24f40002f1dd |
| 18 | Policy Clause          | Specific clauses within an insurance policy.                                                                                   | 220be81ed416456bbf0e24f40002f1dd |
| 19 | Case Law               | A specific case law cited in the context.                                                                                      | 6d52e74566b0447e9afbaaf44d580550 |
| 20 | Insurer                | An insurance company referenced in the context.                                                                                | 6d52e74566b0447e9afbaaf44d580550 |
| 21 | Insurance Coverage     | A specific line of insurance coverage mentioned in the context.                                                                | 6d52e74566b0447e9afbaaf44d580550 |
| 22 | Legal Concept          | A legal concept or principle relevant to insurance coverage.                                                                   | 6d52e74566b0447e9afbaaf44d580550 |
| 23 | Plaintiff              | An individual who brings a case against another in a court of law.                                                             | 6d52e74566b0447e9afbaaf44d580550 |
| 24 | Defendant              | An individual, company, or institution sued or accused in a court of law.                                                      | 6d52e74566b0447e9afbaaf44d580550 |
| 25 | Court                  | A specific court where the legal case was considered.                                                                          | 6d52e74566b0447e9afbaaf44d580550 |
| 26 | Policyholder           | An individual or entity holding an insurance policy.                                                                           | 6d52e74566b0447e9afbaaf44d580550 |
| 27 | Settlement             | An amount paid in settlement of a claim against a policyholder.                                                                | 6d52e74566b0447e9afbaaf44d580550 |
| 28 | Verdict                | The decision reached by a jury in a court case.                                                                                | 6d52e74566b0447e9afbaaf44d580550 |
| 29 | Insurance Company      | A company that provides insurance coverage.                                                                                    | e16e40749abd4ac7907381e4c60c1b7e |
| 30 | Policyholder           | An individual or entity that holds an insurance policy.                                                                        | e16e40749abd4ac7907381e4c60c1b7e |
| 31 | Insurance Policy       | A document detailing the terms and conditions of a contract of insurance.                                                      | e16e40749abd4ac7907381e4c60c1b7e |
| 32 | Legal Case             | A legal dispute brought before a court for adjudication.                                                                       | e16e40749abd4ac7907381e4c60c1b7e |
| 33 | Court                  | A governmental institution with the authority to adjudicate legal disputes.                                                    | e16e40749abd4ac7907381e4c60c1b7e |
| 34 | Equitable Contribution | A legal principle allowing one insurer to seek a proportionate share of liability from another insurer covering the same risk. | e16e40749abd4ac7907381e4c60c1b7e |
| 35 | Primary Insurance      | Insurance that provides immediate coverage upon the occurrence of a loss.                                                      | e16e40749abd4ac7907381e4c60c1b7e |
| 36 | Excess Insurance       | Insurance that provides additional coverage after the primary insurance limits are exhausted.                                  | e16e40749abd4ac7907381e4c60c1b7e |
| 37 | Injury                 | Physical harm or damage to a person.                                                                                           | e16e40749abd4ac7907381e4c60c1b7e |
| 38 | Endorsement            | An amendment to an insurance policy that changes its terms or conditions.                                                      | e16e40749abd4ac7907381e4c60c1b7e |
| 39 | Settlement             | An agreement reached between parties in a legal dispute, typically involving payment to the injured party.                     | e16e40749abd4ac7907381e4c60c1b7e |
| 40 | Additional Insured     | An entity or individual covered by an insurance policy in addition to the primary policyholder.                                | e16e40749abd4ac7907381e4c60c1b7e |

In [11]:
def df2Graph(df, model="gpt-4o") -> pd.DataFrame:
    progress_bar = tqdm(total=len(df), desc="Processing chunks")

    def apply_graphPrompt(row):
        result = generate_graph(row.text, model)
        result_df = result.to_pandas
        result_df["chunk_id"] = row.chunk_id
        progress_bar.update(1)
        return result_df
    
    results = df.apply(apply_graphPrompt, axis=1)
    results_df = pd.concat(results.tolist(), ignore_index=True)
    progress_bar.close()
    
    return results_df

In [12]:
dfg = df2Graph(df)

Processing chunks:   0%|          | 0/5 [00:00<?, ?it/s]

INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
Processing chunks:  20%|██        | 1/5 [00:30<02:00, 30.18s/it]INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
Processing chunks:  40%|████      | 2/5 [00:53<01:17, 25.98s/it]INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
Processing chunks:  60%|██████    | 3/5 [01:03<00:37, 18.92s/it]INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
Processing chunks:  80%|████████  | 4/5 [01:21<00:18, 18.38s/it]INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
Processing chunks: 100%|██████████| 5/5 [01:27<00:00, 17.60s/it]


In [112]:
print(dfg.shape)
dfg.head()

(72, 6)


Unnamed: 0,node_1,node_2,edge,node_1_type,node_2_type,chunk_id
0,JUSTICE GEIGER,opinion of the court,delivered,Person,Miscellaneous,f3b37b0a5f5f493caa3dc83df34fb354
1,Konami,patent infringement lawsuit,was sued by a business competitor for,Organization,Event,f3b37b0a5f5f493caa3dc83df34fb354
2,Hartford Insurance Company of Illinois,Konami,was tendered defense of by the plaintiff,Organization,Organization,f3b37b0a5f5f493caa3dc83df34fb354
3,Konami,Hartford Insurance Company of Illinois,brought a breach of contract action against,Organization,Organization,f3b37b0a5f5f493caa3dc83df34fb354
4,Hartford Insurance Company of Illinois,Konami,provided insurance policy to,Organization,Organization,f3b37b0a5f5f493caa3dc83df34fb354


In [16]:
dfg['node_1'].nunique(), dfg['node_2'].nunique()

(38, 33)

In [113]:
from openai import AsyncOpenAI
import asyncio
from tqdm.asyncio import tqdm_asyncio

class NodeEntity(BaseModel):
    """The original and resolved entity name from a list of nodes."""
    
    original_name: str = Field(
        ...,
        description="The original entity given by the user.",
    )
    resolved_name: str = Field(
        ...,
        description="The name of the entity such that duplications are resolved and issues with punctuation or capitalization are corrected.",
    )
    
    
class ResolvedEntities(BaseModel):
    """A list of entities."""
    
    entities: List[NodeEntity] = Field(..., default_factory=list)
    
    @property
    def to_pandas(self):
        entity_dict = {
            "original_name": [n.original_name for n in self.entities],
            "resolved_name": [n.resolved_name for n in self.entities],
        }
        return pd.DataFrame(entity_dict)
    
    
def system_message() -> str:
    return (
        "You are an expert entity resolution AI. "
        "Users will provide you with a list of Node names representing entities from a knowledge graph separated by new lines. "
        "Your task is to, for each entity, generate a NodeEntity such that duplicates are removed by determining their resolved name. "
        "Try to infer the base name that uniquely describes the entity as concisely as possible. "
        "For example, 'Liberty Mutual Group', 'Liberty Mutual Insurance Co', 'Liberty Mutual Company of Massachusets', 'Liberty Mutual Casualty Insurance' "
        "should all simply be 'Liberty Mutual'. "
        "Please also correct any punctuation or capitalization issues."
    )


async def resolve_entities(entities: List[str], model: str = 'gpt-4o') -> str:
    entity_list_string = "\n ".join(entities)
    client = instructor.from_openai(AsyncOpenAI())
    return await client.chat.completions.create(
        model=model,
        max_retries=5,
        messages=[
            {
                "role": "system",
                "content": system_message(),
            },
            {
                "role": "user",
                "content": f"Here is the list of entities to resolve:\n\n{entity_list_string}",
            },
        ],
        response_model=ResolvedEntities,
    )
  
  
async def process_batch(df_batch: pd.DataFrame, column_name: str, model: str) -> pd.DataFrame:
    entities = df_batch[column_name].tolist()
    resolved_entities = await resolve_entities(entities, model)
    resolved_df = resolved_entities.to_pandas
    resolved_df.index = df_batch.index
    return df_batch.join(resolved_df.set_index("original_name"), on=column_name)


async def process_dataframe(df: pd.DataFrame, column_name: str, model: str, batch_size: int = 50) -> pd.DataFrame:
    tasks = []
    for start in range(0, len(df), batch_size):
        df_batch = df.iloc[start:start + batch_size]
        tasks.append(process_batch(df_batch, column_name, model))
    
    results = []
    for f in tqdm_asyncio.as_completed(tasks, desc="Processing batches"):
        result = await f
        results.append(result)
    
    return pd.concat(results, ignore_index=True)

# Example usage
# import nest_asyncio
# nest_asyncio.apply()

# column_name = "node"
# model = "gpt-4o"

# resolved_df = asyncio.run(process_dataframe(df, column_name, model))

In [114]:
import nest_asyncio
nest_asyncio.apply()

column_name = "node_1"
model = "gpt-4o"

resolved_df = asyncio.run(process_dataframe(dfg, column_name, model))

Processing batches:   0%|          | 0/2 [00:00<?, ?it/s]INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
Processing batches:  50%|█████     | 1/2 [00:04<00:04,  4.88s/it]INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
Processing batches:  50%|█████     | 1/2 [00:06<00:06,  6.06s/it]


ValueError: Length mismatch: Expected axis has 23 elements, new values have 50 elements

In [115]:
dfg.shape

(72, 6)

In [17]:
def contextual_proximity(df):
    # Melt the dataframe into a list of nodes
    dfg_long = pd.melt(
        df,
        id_vars=["chunk_id"],
        value_vars=["node_1", "node_2"],
        value_name="node",
    )
    dfg_long.drop(columns=["variable"], inplace=True)
    # Self join with chunk id as the key will create a link between terms occurring in the same text chunk
    dfg_wide = pd.merge(dfg_long, dfg_long, on="chunk_id", suffixes=("_1", "_2"))
    # drop self loops
    self_loops_drop = dfg_wide[dfg_wide["node_1"] == dfg_wide["node_2"]].index
    dfg2 = dfg_wide.drop(index=self_loops_drop).reset_index(drop=True)
    # Group and count edges
    dfg2 = (
        dfg2.groupby(["node_1", "node_2"])
        .agg({"chunk_id": [",".join, "count"]})
        .reset_index()
    )
    dfg2.columns = ["node_1", "node_2", "chunk_id", "count"]
    dfg2.replace("", np.nan, inplace=True)
    dfg2.dropna(subset=["node_1", "node_2"], inplace=True)
    # Drop edges with 1 count
    dfg2 = dfg2[dfg2["count"] != 1]
    dfg2["edge"] = "chunk contextual proximity"
    return dfg2

In [18]:
dfg2 = contextual_proximity(dfg)

In [19]:
dfg2

Unnamed: 0,node_1,node_2,chunk_id,count,edge
4,"$984,943.15",Hartford,"0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4...",7,chunk contextual proximity
5,"$984,943.15",Hartford Casualty,"0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4...",2,chunk contextual proximity
6,"$984,943.15","January 17, 1995","0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4...",2,chunk contextual proximity
9,"$984,943.15",Konami,"0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4...",17,chunk contextual proximity
10,"$984,943.15",Land & Sky,"0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4...",2,chunk contextual proximity
...,...,...,...,...,...
1083,trial,Union Insurance Co.,"0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4...",5,chunk contextual proximity
1084,trial,advertising injury provision,"0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4...",10,chunk contextual proximity
1085,trial,appeal,"0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4...",10,chunk contextual proximity
1086,trial,section 155,"0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4...",5,chunk contextual proximity


In [20]:
dfg_combined = pd.concat([dfg, dfg2])

In [21]:
print(dfg_combined.shape)
dfg_combined.head()

(710, 7)


Unnamed: 0,node_1,node_2,edge,node_1_type,node_2_type,chunk_id,count
0,JUSTICE GEIGER,opinion of the court,delivered,Person,Miscellaneous,f3b37b0a5f5f493caa3dc83df34fb354,
1,Konami,patent infringement lawsuit,was sued by a business competitor for,Organization,Event,f3b37b0a5f5f493caa3dc83df34fb354,
2,Hartford Insurance Company of Illinois,Konami,was tendered defense of by the plaintiff,Organization,Organization,f3b37b0a5f5f493caa3dc83df34fb354,
3,Konami,Hartford Insurance Company of Illinois,brought a breach of contract action against,Organization,Organization,f3b37b0a5f5f493caa3dc83df34fb354,
4,Hartford Insurance Company of Illinois,Konami,provided insurance policy to,Organization,Organization,f3b37b0a5f5f493caa3dc83df34fb354,


In [22]:
dfg_final = (
    dfg_combined.groupby(["node_1", "node_2"])
    .agg(
        {
            # Convert to string and concatenate unique, non-null chunk_ids
            "chunk_id": lambda x: ",".join(
                set([str(i) for i in x if pd.notna(i)])
            ),
            # Convert to string and concatenate unique, non-null edge descriptions
            "edge": lambda x: ",".join(
                set([str(i) for i in x if pd.notna(i)])
            ),
            # Sum the weights (counts)
            "count": "sum",
        }
    )
    .reset_index()
)

In [23]:
print(dfg_final.shape)
dfg_final.head()

(648, 5)


Unnamed: 0,node_1,node_2,chunk_id,edge,count
0,"$984,943.15",Hartford,"0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4...",chunk contextual proximity,7.0
1,"$984,943.15",Hartford Casualty,"0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4...",chunk contextual proximity,2.0
2,"$984,943.15","January 17, 1995","0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4...",chunk contextual proximity,2.0
3,"$984,943.15",Konami,"0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4...",chunk contextual proximity,17.0
4,"$984,943.15",Land & Sky,"0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4...",chunk contextual proximity,2.0


In [24]:
nodes = pd.concat(
    [dfg_final["node_1"], dfg_final["node_2"]], axis=0
).unique()

len(nodes)

59

In [25]:
import networkx as nx

G = nx.Graph()

In [26]:
for node in nodes:
    G.add_node(str(node))
# Add edges to the graph
for _, row in dfg_final.iterrows():
    G.add_edge(
        str(row["node_1"]),
        str(row["node_2"]),
        title=row["edge"],
        weight=row["count"] / 4,
        chunk_id=row["chunk_id"],
    )

In [27]:
def detect_communities(G):
    communities_generator = nx.community.girvan_newman(G)
    next_level_communities = next(communities_generator)
    communities = sorted(map(sorted, next_level_communities))
    return communities

In [28]:
communities = detect_communities(G)

In [29]:
import logging
import random
import seaborn as sns


palette = "hls"
p = sns.color_palette(palette, len(communities)).as_hex()
random.shuffle(p)
rows = []
group = 0
for community in communities:
    color = p.pop()
    group += 1
    for node in community:
        rows += [{"node": node, "color": color, "group": group}]
df_colors = pd.DataFrame(rows)

for _, row in df_colors.iterrows():
    G.nodes[row["node"]]["group"] = row["group"]
    G.nodes[row["node"]]["color"] = row["color"]
    G.nodes[row["node"]]["size"] = G.degree[row["node"]]


In [30]:
import json

graph_data = nx.node_link_data(G)

# Specify the file path where you want to save the JSON
json_file_path = "data/graph_data.json"

# Write the graph data to a JSON file
with open(json_file_path, "w", encoding="utf-8-sig") as json_file:
    json.dump(graph_data, json_file, ensure_ascii=False)

In [31]:
graph_data["links"]

[{'title': 'chunk contextual proximity',
  'weight': 1.75,
  'chunk_id': '0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4d1b9f43d4144b3e7e67',
  'source': '$984,943.15',
  'target': 'Hartford'},
 {'title': 'chunk contextual proximity',
  'weight': 0.5,
  'chunk_id': '0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4d1b9f43d4144b3e7e67',
  'source': '$984,943.15',
  'target': 'Hartford Casualty'},
 {'title': 'chunk contextual proximity',
  'weight': 0.5,
  'chunk_id': '0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4d1b9f43d4144b3e7e67',
  'source': '$984,943.15',
  'target': 'January 17, 1995'},
 {'title': 'chunk contextual proximity,The amount included in the judgment in favor of Konami.',
  'weight': 4.25,
  'chunk_id': '0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884

In [34]:
def load_graph_from_json(json_file_path):
    try:
        with open(json_file_path, 'r', encoding='utf-8-sig') as file:
            json_data = json.load(file)
    except FileNotFoundError:
        print(f"File not found: {json_file_path}")
        return None
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON from the file: {json_file_path}. Error: {e}")
        return None
    except Exception as e:
        print(f"Unexpected error while reading the file: {e}")
        return None

    G = nx.Graph()
    for node in json_data['nodes']:
        G.add_node(node['id'])
    for link in json_data['links']:
        G.add_edge(link['source'], link['target'], weight=link['weight'], title=link['title'])
    return G


# function to load df from json for chunk retrieval based on node and chunk id
def load_chunks_dataframe(json_file_path):
    try:
        with open(json_file_path, 'r', encoding='utf-8-sig') as file:
            data = json.load(file)

        # Extracting chunks and their IDs
        chunks = []
        for link in data['links']:
            chunk_ids = link.get('chunk_id', '').split(',')
            text = link.get('title', '')  # Assuming 'title' contains the text associated with the chunk
            for chunk_id in chunk_ids:
                if chunk_id:
                    chunks.append({'chunk_id': chunk_id, 'text': text})

        return pd.DataFrame(chunks)

    except Exception as e:
        print(f"Error: {e}")
        return pd.DataFrame()  # Return an empty DataFrame in case of an error

In [35]:
# Load the graph
graph = load_graph_from_json('data/graph_data.json')
# load the chunk_dataframe
chunks_dataframe = load_chunks_dataframe('data/graph_data.json')

In [37]:
chunks_dataframe[chunks_dataframe['text']!='chunk contextual proximity']

Unnamed: 0,chunk_id,text
11,0d1336884d5b4d1b9f43d4144b3e7e67,"chunk contextual proximity,The amount included..."
12,0d1336884d5b4d1b9f43d4144b3e7e67,"chunk contextual proximity,The amount included..."
13,0d1336884d5b4d1b9f43d4144b3e7e67,"chunk contextual proximity,The amount included..."
14,0d1336884d5b4d1b9f43d4144b3e7e67,"chunk contextual proximity,The amount included..."
15,0d1336884d5b4d1b9f43d4144b3e7e67,"chunk contextual proximity,The amount included..."
...,...,...
2070,7673b452763e4cb3b6cbca0d8abc3e83,Direct patent infringement refers to the makin...
2077,7673b452763e4cb3b6cbca0d8abc3e83,"chunk contextual proximity,Some dictionaries d..."
2078,7673b452763e4cb3b6cbca0d8abc3e83,"chunk contextual proximity,Some dictionaries d..."
2079,7673b452763e4cb3b6cbca0d8abc3e83,"chunk contextual proximity,Some dictionaries d..."


In [38]:
import re


def textualize_graph(graph):
    triplets = re.findall(r'\((.*?)\)', graph)
    nodes = {}
    edges = []
    for tri in triplets:
        src, edeg_attr, dst = tri.split(';')
        src = src.lower().strip()
        dst = dst.lower().strip()
        if src not in nodes:
            nodes[src] = len(nodes)
        if dst not in nodes:
            nodes[dst] = len(nodes)
        edges.append({'src': nodes[src], 'edge_attr': edeg_attr.lower().strip(), 'dst': nodes[dst], })

    nodes = pd.DataFrame(nodes.items(), columns=['node_attr', 'node_id'])
    edges = pd.DataFrame(edges)
    return nodes, edges

In [39]:
from src.embedding_models.models import OpenAIEmbeddings
import nest_asyncio

embeddings = OpenAIEmbeddings()
embedding_fn = embeddings.embedding_fn()

In [43]:
node_features = pd.DataFrame(graph_data["nodes"])
node_embeddings = embedding_fn(node_features["id"].tolist())
node_features["embeddings"] = node_embeddings

INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


In [44]:
print(node_features.shape)
node_features.head()

(59, 5)


Unnamed: 0,group,color,size,id,embeddings
0,1,#db5f57,9,"$984,943.15","[0.01951550878584385, -0.01489550992846489, 0...."
1,1,#db5f57,6,935 F. Supp. at 1116,"[-0.00864872895181179, 0.005267801228910685, 0..."
2,1,#db5f57,5,Advertising mode,"[-0.006128218956291676, -0.011504832655191422,..."
3,1,#db5f57,9,"April 16, 1999","[-0.0065999156795442104, -0.031243011355400085..."
4,1,#db5f57,9,"April 23, 1996","[-0.001258106785826385, -0.008358835242688656,..."


In [45]:
edge_features = pd.DataFrame(graph_data["links"])
edge_embeddings = embedding_fn(edge_features["title"].tolist())
node_embeddings = embedding_fn(edge_features["source"].tolist())
edge_features["embeddings"] = edge_embeddings

INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


In [46]:
edge_features

Unnamed: 0,title,weight,chunk_id,source,target,embeddings
0,chunk contextual proximity,1.75,"0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4...","$984,943.15",Hartford,"[-0.025688467547297478, -0.01057156641036272, ..."
1,chunk contextual proximity,0.50,"0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4...","$984,943.15",Hartford Casualty,"[-0.025688467547297478, -0.01057156641036272, ..."
2,chunk contextual proximity,0.50,"0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4...","$984,943.15","January 17, 1995","[-0.025688467547297478, -0.01057156641036272, ..."
3,"chunk contextual proximity,The amount included...",4.25,"0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4...","$984,943.15",Konami,"[0.0019303852459415793, -0.0123488400131464, 0..."
4,chunk contextual proximity,0.50,"0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4...","$984,943.15",Land & Sky,"[-0.025688467547297478, -0.01057156641036272, ..."
...,...,...,...,...,...,...
324,"chunk contextual proximity,Some dictionaries d...",0.75,"7673b452763e4cb3b6cbca0d8abc3e83,7673b452763e4...",patent infringement,piracy,"[-0.01839851774275303, -0.02950393036007881, 0..."
325,chunk contextual proximity,0.75,"7673b452763e4cb3b6cbca0d8abc3e83,7673b452763e4...",patent infringement,sale of a patented component,"[-0.025586551055312157, -0.01060118991881609, ..."
326,chunk contextual proximity,0.50,"0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4...",section 155,summary judgment,"[-0.025688467547297478, -0.01057156641036272, ..."
327,chunk contextual proximity,1.25,"0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4...",section 155,trial,"[-0.025688467547297478, -0.01057156641036272, ..."


In [47]:
from torch_geometric.utils import from_networkx
import torch

In [48]:
data = from_networkx(G)

In [49]:
data.x = torch.tensor(node_embeddings, dtype=torch.float)
data.edge_attr = torch.tensor(edge_embeddings, dtype=torch.float)

In [50]:
import torch
import numpy as np
from pcst_fast import pcst_fast
from torch_geometric.data.data import Data


def retrieval_via_pcst(graph, q_emb, textual_nodes, textual_edges, topk=3, topk_e=3, cost_e=0.5):
    c = 0.01
    if len(textual_nodes) == 0 or len(textual_edges) == 0:
        desc = textual_nodes.to_csv(index=False) + '\n' + textual_edges.to_csv(index=False, columns=['src', 'edge_attr', 'dst'])
        graph = Data(x=graph.x, edge_index=graph.edge_index, edge_attr=graph.edge_attr, num_nodes=graph.num_nodes)
        return graph, desc

    root = -1  # unrooted
    num_clusters = 1
    pruning = 'gw'
    verbosity_level = 0
    if topk > 0:
        n_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, graph.x)
        topk = min(topk, graph.num_nodes)
        _, topk_n_indices = torch.topk(n_prizes, topk, largest=True)

        n_prizes = torch.zeros_like(n_prizes)
        n_prizes[topk_n_indices] = torch.arange(topk, 0, -1).float()
    else:
        n_prizes = torch.zeros(graph.num_nodes)

    if topk_e > 0:
        e_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, graph.edge_attr)
        topk_e = min(topk_e, e_prizes.unique().size(0))

        topk_e_values, _ = torch.topk(e_prizes.unique(), topk_e, largest=True)
        e_prizes[e_prizes < topk_e_values[-1]] = 0.0
        last_topk_e_value = topk_e
        for k in range(topk_e):
            indices = e_prizes == topk_e_values[k]
            value = min((topk_e-k)/sum(indices), last_topk_e_value)
            e_prizes[indices] = value
            last_topk_e_value = value*(1-c)
        # reduce the cost of the edges such that at least one edge is selected
        cost_e = min(cost_e, e_prizes.max().item()*(1-c/2))
    else:
        e_prizes = torch.zeros(graph.num_edges)

    costs = []
    edges = []
    vritual_n_prizes = []
    virtual_edges = []
    virtual_costs = []
    mapping_n = {}
    mapping_e = {}
    for i, (src, dst) in enumerate(graph.edge_index.T.numpy()):
        prize_e = e_prizes[i]
        if prize_e <= cost_e:
            mapping_e[len(edges)] = i
            edges.append((src, dst))
            costs.append(cost_e - prize_e)
        else:
            virtual_node_id = graph.num_nodes + len(vritual_n_prizes)
            mapping_n[virtual_node_id] = i
            virtual_edges.append((src, virtual_node_id))
            virtual_edges.append((virtual_node_id, dst))
            virtual_costs.append(0)
            virtual_costs.append(0)
            vritual_n_prizes.append(prize_e - cost_e)

    prizes = np.concatenate([n_prizes, np.array(vritual_n_prizes)])
    num_edges = len(edges)
    if len(virtual_costs) > 0:
        costs = np.array(costs+virtual_costs)
        edges = np.array(edges+virtual_edges)

    vertices, edges = pcst_fast(edges, prizes, costs, root, num_clusters, pruning, verbosity_level)

    selected_nodes = vertices[vertices < graph.num_nodes]
    selected_edges = [mapping_e[e] for e in edges if e < num_edges]
    virtual_vertices = vertices[vertices >= graph.num_nodes]
    if len(virtual_vertices) > 0:
        virtual_vertices = vertices[vertices >= graph.num_nodes]
        virtual_edges = [mapping_n[i] for i in virtual_vertices]
        selected_edges = np.array(selected_edges+virtual_edges)

    edge_index = graph.edge_index[:, selected_edges]
    selected_nodes = np.unique(np.concatenate([selected_nodes, edge_index[0].numpy(), edge_index[1].numpy()]))

    n = textual_nodes.iloc[selected_nodes]
    e = textual_edges.iloc[selected_edges]
    desc = n.to_csv(index=False)+'\n'+e.to_csv(index=False, columns=['src', 'edge_attr', 'dst'])

    mapping = {n: i for i, n in enumerate(selected_nodes.tolist())}

    x = graph.x[selected_nodes]
    edge_attr = graph.edge_attr[selected_edges]
    src = [mapping[i] for i in edge_index[0].tolist()]
    dst = [mapping[i] for i in edge_index[1].tolist()]
    edge_index = torch.LongTensor([src, dst])
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, num_nodes=len(selected_nodes))

    return data, desc

In [51]:
query = "What is the court decision involving united states gypsum?"

q_emb = embedding_fn(query)

INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


In [52]:
q_emb = torch.tensor(q_emb[0], dtype=torch.float) 

In [53]:
retrieved_data, description = retrieval_via_pcst(data, q_emb, node_features, edge_features)

IndexError: index 329 is out of bounds for dimension 0 with size 329

In [55]:
from pyvis.network import Network

net = Network(
    notebook=False,
    bgcolor="#1a1a1a",
    cdn_resources="remote",
    height="900px",
    width="100%",
    select_menu=True,
    font_color="#cccccc",
    filter_menu=False,
)
net.from_nx(G)
net.force_atlas_2based(central_gravity=0.015, gravity=-31)
net.show_buttons(filter_=["physics"])
html_output_path = os.path.join("data", "index.html")
html = net.generate_html()
with open(html_output_path, mode="w", encoding="utf-8-sig") as fp:
    fp.write(html)
net.show(html_output_path, notebook=False)

data\index.html


In [56]:
print(dfg_final.shape)
dfg_final.head()

(648, 5)


Unnamed: 0,node_1,node_2,chunk_id,edge,count
0,"$984,943.15",Hartford,"0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4...",chunk contextual proximity,7.0
1,"$984,943.15",Hartford Casualty,"0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4...",chunk contextual proximity,2.0
2,"$984,943.15","January 17, 1995","0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4...",chunk contextual proximity,2.0
3,"$984,943.15",Konami,"0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4...",chunk contextual proximity,17.0
4,"$984,943.15",Land & Sky,"0d1336884d5b4d1b9f43d4144b3e7e67,0d1336884d5b4...",chunk contextual proximity,2.0


In [57]:
from src.embedding_models.models import OpenAIEmbeddings

embeddings = OpenAIEmbeddings()

In [58]:
embeddings = OpenAIEmbeddings()

In [59]:
pd.DataFrame(nodes)

Unnamed: 0,0
0,"$984,943.15"
1,935 F. Supp. at 1116
2,Advertising mode
3,"April 16, 1999"
4,"April 23, 1996"
5,Brochu
6,CGL policy
7,Complaint
8,Device Insertion
9,Du Page County


In [106]:
from openai import AsyncOpenAI
import asyncio
from tqdm.asyncio import tqdm_asyncio

class NodeEntity(BaseModel):
    """The original and resolved entity name from a list of nodes."""
    
    original_name: str = Field(
        ...,
        description="The original entity given by the user.",
    )
    resolved_name: str = Field(
        ...,
        description="The name of the entity such that duplications are resolved and issues with punctuation or capitalization are corrected.",
    )
    
    
class ResolvedEntities(BaseModel):
    """A list of entities."""
    
    entities: List[NodeEntity] = Field(..., default_factory=list)
    
    @property
    def to_pandas(self):
        entity_dict = {
            "original_name": [n.original_name for n in self.entities],
            "resolved_name": [n.resolved_name for n in self.entities],
        }
        return pd.DataFrame(entity_dict)
    
    
def system_message() -> str:
    return (
        "You are an expert entity resolution AI. "
        "Users will provide you with a list of Node names representing entities from a knowledge graph separated by new lines. "
        "Your task is to, for each entity, generate a NodeEntity such that duplicates are removed by determining their resolved name. "
        "Try to infer the base name that uniquely describes the entity as concisely as possible. "
        "For example, 'Liberty Mutual Group', 'Liberty Mutual Insurance Co', 'Liberty Mutual Company of Massachusets', 'Liberty Mutual Casualty Insurance' "
        "should all simply be 'Liberty Mutual'. "
        "Please also correct any punctuation or capitalization issues."
    )


async def resolve_entities(entities: List[str], model: str = 'gpt-4o') -> str:
    entity_list_string = "\n ".join(entities)
    client = instructor.from_openai(AsyncOpenAI())
    return await client.chat.completions.create(
        model=model,
        max_retries=5,
        messages=[
            {
                "role": "system",
                "content": system_message(),
            },
            {
                "role": "user",
                "content": f"Here is the list of entities to resolve:\n\n{entity_list_string}",
            },
        ],
        response_model=ResolvedEntities,
    )
  
  
async def process_batch(df_batch: pd.DataFrame, column_name: str, model: str) -> pd.DataFrame:
    entities = df_batch[column_name].tolist()
    resolved_entities = await resolve_entities(entities, model)
    resolved_df = resolved_entities.to_pandas
    resolved_df.index = df_batch.index
    return df_batch.join(resolved_df.set_index("original_name"), on=column_name)


async def process_dataframe(df: pd.DataFrame, column_name: str, model: str, batch_size: int = 50) -> pd.DataFrame:
    tasks = []
    for start in range(0, len(df), batch_size):
        df_batch = df.iloc[start:start + batch_size]
        tasks.append(process_batch(df_batch, column_name, model))
    
    results = []
    for f in tqdm_asyncio.as_completed(tasks, desc="Processing batches"):
        result = await f
        results.append(result)
    
    return pd.concat(results, ignore_index=True)

# Example usage
# import nest_asyncio
# nest_asyncio.apply()

# column_name = "node"
# model = "gpt-4o"

# resolved_df = asyncio.run(process_dataframe(df, column_name, model))

In [76]:
resolved_entity_nodes = resolve_entities(nodes.tolist())

INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


In [97]:
test_df = pd.DataFrame({'node': nodes})
test_df.head()

Unnamed: 0,node
0,"$984,943.15"
1,935 F. Supp. at 1116
2,Advertising mode
3,"April 16, 1999"
4,"April 23, 1996"


In [110]:
import nest_asyncio
nest_asyncio.apply()

column_name = "node"
model = "gpt-4o"

resolved_df = asyncio.run(process_dataframe(test_df, column_name, model))

Processing batches:   0%|          | 0/2 [00:00<?, ?it/s]INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
Processing batches:  50%|█████     | 1/2 [00:03<00:03,  3.47s/it]INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
Processing batches: 100%|██████████| 2/2 [00:12<00:00,  6.41s/it]


In [111]:
resolved_df

Unnamed: 0,node,resolved_name
0,"making, using, or selling of a patented invention","Making, Using, or Selling of a Patented Invention"
1,opinion of the court,Opinion of the Court
2,patent infringement,Patent Infringement
3,patent infringement lawsuit,Patent Infringement
4,piracy,Piracy
5,sale of a patented component,Sale of a Patented Component
6,section 155,Section 155
7,summary judgment,Summary Judgment
8,trial,Trial
9,"$984,943.15",984943.15


In [5]:
docs = [Document(text=t) for t in example_text_list]

In [6]:
graph_maker = GraphMaker(ontology=example_ontology, verbose=True)

In [7]:
graph = graph_maker.from_documents(
    docs[:3], 
    delay_s_between=1,
    ) 

[92m[39m
[92m▶︎ GRAPH MAKER LOG - 2024-05-10 16:39:48 - INFO [39m
[92mDocument: 1[39m
[92m[39m
[34m[39m
[34m▶︎ GRAPH MAKER VERBOSE - 2024-05-10 16:39:48 - INFO [39m
[34mUsing Ontology:
labels=[{'Person': 'Person name without any adjectives, Remember a person may be referenced by their name or using a pronoun'}, {'Object': "Do not add the definite article 'the' in the object name"}, {'Event': 'Event event involving multiple people. Do not include qualifiers or verbs like gives, leaves, works etc.'}, 'Place', 'Document', 'Organization', 'Action', {'Miscellaneous': 'Any important concept can not be categorized with any other given label'}] relationships=['Relation between any pair of Entities'][39m
[34m[39m
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
[34m[39m
[34m▶︎ GRAPH MAKER VERBOSE - 2024-05-10 16:40:22 - INFO [39m
[34mLLM Response:
[
   {
       "node_1": {"label": "Person", "name": "Bilbo Baggins"},
       "node_2":

In [9]:
graph[0]

[Edge(node_1=Node(label='Person', name='Bilbo Baggins'), node_2=Node(label='Event', name='birthday'), relationship='Bilbo Baggins celebrates his birthday.', metadata=None, order=0),
 Edge(node_1=Node(label='Person', name='Bilbo Baggins'), node_2=Node(label='Person', name='Frodo'), relationship='Bilbo Baggins leaves the Ring to Frodo, who is his heir.', metadata=None, order=0),
 Edge(node_1=Node(label='Person', name='Gandalf'), node_2=Node(label='Object', name='Ring'), relationship='Gandalf suspects and later confirms that the Ring is a Ring of Power.', metadata=None, order=0),
 Edge(node_1=Node(label='Person', name='Ring'), node_2=Node(label='Person', name='Dark Lord Sauron'), relationship='The Ring was lost by Dark Lord Sauron.', metadata=None, order=0),
 Edge(node_1=Node(label='Person', name='Gandalf'), node_2=Node(label='Person', name='Frodo'), relationship='Gandalf counsels Frodo to take the Ring away from the Shire and promises to return.', metadata=None, order=0),
 Edge(node_1=No

___

# instructor KG with iterative updates

In [11]:
from pydantic import BaseModel, Field
from typing import List


class Node(BaseModel):
    id: int
    label: str
    color: str
    
    def __hash__(self) -> int:
        return hash((id, self.label))

class Edge(BaseModel):
    source: int
    target: int
    label: str
    color: str = "black"
    
    def __hash__(self) -> int:
        return hash((self.source, self.target, self.label))

class KnowledgeGraph(BaseModel):
    nodes: List[Node] = Field(..., default_factory=list)
    edges: List[Edge] = Field(..., default_factory=list)

In [12]:
from openai import OpenAI
import instructor

client = instructor.from_openai(OpenAI())

def generate_graph(input) -> KnowledgeGraph:
    return client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {
                "role": "user",
                "content": f"Help me understand the following by describing it as a detailed knowledge graph: {input}",
            }
        ],
        response_model=KnowledgeGraph,
    )

In [9]:
graph_2 = generate_graph(input=example_text_list[0])

INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


In [11]:
from graphviz import Digraph

def visualize_knowledge_graph(kg: KnowledgeGraph):
    dot = Digraph(comment="Knowledge Graph")

    # Add nodes
    for node in kg.nodes:
        dot.node(str(node.id), node.label, color=node.color)

    # Add edges
    for edge in kg.edges:
        dot.edge(str(edge.source), str(edge.target), label=edge.label, color=edge.color)

    # Render the graph
    dot.render("knowledge_graph.gv", view=True)

In [12]:
visualize_knowledge_graph(graph_2)

In [10]:
from graphviz import Digraph
from typing import Optional
from pydantic import BaseModel, Field
from typing import List

class Node(BaseModel):
    id: int
    label: str
    color: str
    
    def __hash__(self) -> int:
        return hash((id, self.label))

class Edge(BaseModel):
    source: int
    target: int
    label: str
    color: str = "black"
    
    def __hash__(self) -> int:
        return hash((self.source, self.target, self.label))


class KnowledgeGraph(BaseModel):
    nodes: Optional[List[Node]] = Field(..., default_factory=list)
    edges: Optional[List[Edge]] = Field(..., default_factory=list)

    def update(self, other: "KnowledgeGraph") -> "KnowledgeGraph":
        """Updates the current graph with the other graph, deduplicating nodes and edges."""
        return KnowledgeGraph(
            nodes=list(set(self.nodes + other.nodes)),
            edges=list(set(self.edges + other.edges)),
        )

    def draw(self, prefix: str = None):
        dot = Digraph(comment="Knowledge Graph")

        for node in self.nodes:  
            dot.node(str(node.id), node.label, color=node.color)

        for edge in self.edges:  
            dot.edge(
                str(edge.source), str(edge.target), label=edge.label, color=edge.color
            )
        dot.render(prefix, format="png", view=True)

In [8]:
from openai import OpenAI
import instructor

client = instructor.from_openai(OpenAI())

def generate_graph(input: List[str]) -> KnowledgeGraph:
    cur_state = KnowledgeGraph()  
    num_iterations = len(input)
    for i, inp in enumerate(input):
        new_updates = client.chat.completions.create(
            model="gpt-3.5-turbo-16k",
            messages=[
                {
                    "role": "system",
                    "content": """You are an iterative knowledge graph builder.
                    You are given the current state of the graph, and you must append the nodes and edges
                    to it Do not provide any duplicates and try to reuse nodes as much as possible.""",
                },
                {
                    "role": "user",
                    "content": f"""Extract any new nodes and edges from the following:
                    # Part {i}/{num_iterations} of the input:

                    {inp}""",
                },
                {
                    "role": "user",
                    "content": f"""Here is the current state of the graph:
                    {cur_state.model_dump_json(indent=2)}""",
                },  
            ],
            response_model=KnowledgeGraph,
        )  # type: ignore

        # Update the current state
        cur_state = cur_state.update(new_updates)  
        cur_state.draw(prefix=f"iteration_{i}")
    return cur_state

In [11]:
graph: KnowledgeGraph = generate_graph(example_text_list[:3])
graph.draw(prefix="final")

INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
