In [1]:
import re
import logging

import sys

# Point the first search-path entry at its parent directory so the ``src``
# package imported below can be resolved. The guard keeps this cell
# idempotent: without it, every re-run appends another "/../" to the entry.
if not sys.path[0].endswith("/../"):
    sys.path[0] = sys.path[0] + "/../"

from src.prompts import SYSTEM_PROMPT, ALIGN_PROMPT, DEEP_RELS, QA_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT, MERGE_PROMPT
from src.llm import LLM
from src.text_extract import (
    get_nodes_relationships_from_rawtext,
    nodes_rels_combine_text,
    merge_nodes_rels,
)
from src.tools import encode_image
from src.dataclass import Chunk, Relationship, Entity, Image
from src.multimodal.img import (
    extract_images,
    extract_entity_rels_images,
    merge_entity_rels,
)

# Write INFO-and-above records to log.log, truncating it when logging is first
# configured. Passing filename/filemode (instead of an already-opened stream)
# lets the logging module own the file handle; the file is only opened when
# basicConfig actually installs a handler, so re-running this cell no longer
# truncates the log underneath the handler that is still writing to it.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(message)s",
    filename="log.log",
    filemode="w",
)

In [2]:
def split_document(
        document: str, chunk_size: int = 1000, over_lap: float = 0.1
) -> list[str]:
    """
    Split a Markdown document into chunks with ``MarkdownTextSplitter``.

    Args:
        document: The Markdown text to split.
        chunk_size: Passed to the splitter as ``chunk_size``.
        over_lap: Fraction of ``chunk_size``; ``int(chunk_size * over_lap)``
            is passed to the splitter as ``chunk_overlap``.

    Returns:
        The list of chunks returned by ``split_text``.
    """
    from langchain.text_splitter import MarkdownTextSplitter

    overlap = int(chunk_size * over_lap)
    splitter = MarkdownTextSplitter(chunk_size=chunk_size, chunk_overlap=overlap)
    return splitter.split_text(document)


def _extract_images_paths(document: str) -> list[str]:
    """
    提取文档中的图片路径和图片标题。
    """
    images = re.findall(r"!\[.*?\]\((.*?)\)", document)
    return images

class LLMForEntityExctract:
    """
    Chat wrapper that keeps a running transcript for multi-turn extraction.

    The transcript starts with the system prompt; every call appends the user
    message, sends the whole transcript to the LLM and appends the reply as
    the assistant turn.
    """

    def __init__(self, initial_system_prompt):
        # The system prompt is always the first entry of the transcript.
        system_turn = {"role": "system", "content": initial_system_prompt}
        self.conversation_history = [system_turn]
        self.llm = LLM()

    def add_message_and_call_llm(self, new_message, model="gpt-4o-mini"):
        """
        Append ``new_message`` as a user turn, query the LLM with the full
        transcript and record the reply.

        Args:
            new_message: Content of the user turn to add.
            model: Model name forwarded to ``LLM.chat``.

        Returns:
            Whatever ``LLM.chat`` returned; the same value is stored as the
            content of the new assistant turn.
        """
        history = self.conversation_history
        history.append({"role": "user", "content": new_message})
        reply = self.llm.chat(history, callback=None, model=model)
        history.append({"role": "assistant", "content": reply})
        return reply

def extract_entities_relations(
        chunk: Chunk, prompt, max_gleanings: int = 2
) -> tuple[list["Entity"], list["Relationship"]]:
    """
    Extract the entities and relationships of a chunk with the LLM.

    The chunk text is sent with ``prompt`` as the system message. Up to
    ``max_gleanings`` follow-up rounds then send ``CONTINUE_PROMPT``; between
    rounds ``LOOP_PROMPT`` is sent and the loop stops unless the answer is
    "YES". The last extraction response is parsed with
    ``get_nodes_relationships_from_rawtext``.

    Args:
        chunk: The chunk whose ``text`` is sent to the LLM.
        prompt: System prompt for the extraction conversation.
        max_gleanings: Maximum number of ``CONTINUE_PROMPT`` rounds
            (0 disables gleaning).

    Returns:
        ``(entities, relationships)`` built from the "nodes" and
        "relationships" entries of the parsed response. A non-empty
        "references" property is moved from ``properties`` onto the
        ``references`` attribute of each item.

    Image-based extraction is not performed by this function.
    """
    llm = LLMForEntityExctract(prompt)
    res = llm.add_message_and_call_llm(chunk.text)
    logging.info(f"Extracted entities and relationships from chunk.llm: init iteration res:\n {res}")

    for i in range(max_gleanings):
        # NOTE(review): each gleaning response replaces `res`, so only the
        # last response is parsed below. If CONTINUE_PROMPT asks only for the
        # items that were *missed*, the earlier responses should be merged
        # instead -- confirm against src.prompts.
        res = llm.add_message_and_call_llm(CONTINUE_PROMPT)
        logging.info(f"Extracted entities and relationships from chunk.llm: {i} iteration res:\n {res}")

        # On the final glean there is no point asking whether to continue.
        if i >= max_gleanings - 1:
            break

        continue_res = llm.add_message_and_call_llm(LOOP_PROMPT)
        # Tolerate surrounding whitespace and case differences in the answer,
        # matching how merge_entities interprets its yes/no reply.
        if continue_res.strip().upper() != "YES":
            break

    nodes_rels = get_nodes_relationships_from_rawtext(res)
    entities = [Entity.from_dict(node) for node in nodes_rels["nodes"]]
    relationships = [Relationship.from_dict(rel) for rel in nodes_rels["relationships"]]

    # Promote a non-empty "references" property to the dedicated attribute.
    for item in (*entities, *relationships):
        if item.properties.get("references"):
            item.references = item.properties.pop("references")

    return entities, relationships


def align_entities_relations(
        previous_result: dict[str, list], current_chunk: dict[str, list]
) -> dict[str, list]:
    """
    Ask the LLM to align the previous chunk's entities and relationships
    with those of the current chunk.

    Both inputs are rendered with ``nodes_rels_combine_text`` and sent in one
    user message under ``ALIGN_PROMPT``; the reply is parsed with
    ``get_nodes_relationships_from_rawtext`` and returned.
    """
    previous_text = nodes_rels_combine_text(previous_result)
    current_text = nodes_rels_combine_text(current_chunk)
    user_content = "".join([
        "Previous entities and relationships:\n",
        previous_text,
        "\nCurrent entities and relationships:",
        current_text,
    ])

    system_turn = {"role": "system", "content": ALIGN_PROMPT}
    user_turn = {"role": "user", "content": user_content}
    res = LLM().chat([system_turn, user_turn], callback=None, model="gpt-4o-mini")
    logging.info(f"Aligned entities and relationships.llm res:\n {res}")

    aligned = get_nodes_relationships_from_rawtext(res)
    return aligned


def dig_deep_relationships(nodes_rels: dict[str, list]):
    """
    Send the rendered nodes/relationships to the LLM under the ``DEEP_RELS``
    system prompt and return the parsed reply.

    The input is rendered with ``nodes_rels_combine_text``; the reply is
    parsed with ``get_nodes_relationships_from_rawtext``.
    """
    rendered = nodes_rels_combine_text(nodes_rels)
    conversation = [
        {"role": "system", "content": DEEP_RELS},
        {"role": "user", "content": rendered},
    ]
    llm_res = LLM().chat(conversation, callback=None, model="gpt-4o-mini")
    logging.info(f"Deep relationships.llm res:\n {llm_res}")

    deep_nodes_rels = get_nodes_relationships_from_rawtext(llm_res)
    return deep_nodes_rels

In [3]:
from transformers import BertTokenizer, BertModel
import torch

# Load the pretrained 'bert-base-uncased' tokenizer and model once at module
# level; get_text_embedding below reads both as globals.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

def get_text_embedding(text):
    """
    Embed ``text`` with the module-level BERT tokenizer and model.

    The input is tokenized (truncated to at most 128 tokens) and run through
    the model without gradient tracking. The last hidden state is averaged
    over dim 1, size-1 dimensions are squeezed out, and the result is
    returned as a NumPy array.
    """
    encoded = tokenizer(
        text,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=128,
    )
    with torch.no_grad():
        hidden_states = model(**encoded).last_hidden_state
    pooled = hidden_states.mean(dim=1)
    return pooled.squeeze().numpy()
 



In [4]:
from sklearn.neighbors import NearestNeighbors
import numpy as np

def build_k_nearest_graph(entities, k=5):
    """
    Build a k-nearest-neighbour graph over entity-name embeddings.

    Every entity name is embedded with ``get_text_embedding`` and neighbours
    are looked up by cosine distance. The query set is the fitted set, so an
    entity normally appears among its own neighbours; that self match is
    dropped, leaving at most ``k - 1`` neighbours per entity.

    Args:
        entities: Sequence of objects exposing a ``name`` attribute.
        k: Number of neighbours requested from the index (self included).
            It is clamped to ``len(entities)`` because ``kneighbors`` rejects
            a neighbour count larger than the number of fitted samples.

    Returns:
        A dict mapping each entity name to a list of
        ``(neighbour_entity, cosine_distance)`` tuples, in the order returned
        by ``kneighbors``. An empty input yields an empty dict.
    """
    if len(entities) == 0:
        # Nothing to embed or fit; avoid handing an empty array to the index.
        return {}

    embeddings = np.array([get_text_embedding(entity.name) for entity in entities])
    n_neighbors = min(k, len(entities))
    nbrs = NearestNeighbors(n_neighbors=n_neighbors, metric='cosine').fit(embeddings)
    distances, indices = nbrs.kneighbors(embeddings)

    graph = {}
    for i, entity in enumerate(entities):
        # Row i of `indices` and `distances` are aligned, so pair them
        # directly instead of re-locating each distance with a boolean mask.
        graph[entity.name] = [
            (entities[idx], dist)
            for idx, dist in zip(indices[i], distances[i])
            if idx != i  # Not connected to self
        ]
    return graph


In [5]:
import networkx as nx

def filter_weakly_connected_components(graph, distance_threshold=0.2):
    """
    Group entity names that are linked through close neighbours.

    An undirected edge is added between an entity name and a neighbour's
    name whenever the recorded distance is strictly below
    ``distance_threshold``. Names without any such edge do not appear in the
    result.

    Args:
        graph: Mapping of entity name to ``(neighbour, distance)`` pairs,
            where each neighbour exposes a ``name`` attribute.
        distance_threshold: Exclusive upper bound on the distance of an edge.

    Returns:
        A list with one set of names per connected component.
    """
    close_pairs = (
        (entity, neighbor.name)
        for entity, neighbors in graph.items()
        for neighbor, distance in neighbors
        if distance < distance_threshold
    )
    G = nx.Graph()
    G.add_edges_from(close_pairs)
    return list(nx.connected_components(G))


In [29]:
def merge_entities(entities, components):
    """
    Merge groups of candidate-duplicate entities after asking the LLM.

    For every component the LLM is asked (under ``MERGE_PROMPT``) whether the
    listed names should be merged. On a "yes" answer (case-insensitive,
    surrounding whitespace ignored) the alphabetically first name becomes the
    primary entity. The references and images of the others are appended to
    it, and their properties are applied on top of its own. Otherwise every
    entity of the component is kept as is. Entities that belong to no
    component are appended unchanged at the end.

    Args:
        entities: List of entities (indexed by ``name``) or a dict mapping
            name to entity.
        components: Iterable of collections of entity names.

    Returns:
        ``(merged_entities, original_to_merged_map)``; the map sends every
        original name to the name of the entity that now represents it.

    The input entities are never modified: a merged primary entity is a deep
    copy, so calling this function again on the same inputs gives the same
    result.
    """
    import copy

    merged_entities = []
    original_to_merged_map = {}  # record the mapping of the old name to the new name

    # Convert entities list to a dict keyed by name
    if isinstance(entities, list):
        entities = {entity.name: entity for entity in entities}

    processed_entities = set()  # record processed entities
    llm = None  # created on first use and then reused for every component

    for component in components:
        # Components are typically sets, whose iteration order is arbitrary;
        # sorting makes both the prompt and the primary choice reproducible.
        entity_names = sorted(component)
        if not entity_names:
            continue

        content = "Should the following entities be merged?\n" + "\n".join(entity_names)
        messages = [
            {"role": "system", "content": MERGE_PROMPT},
            {"role": "user", "content": content},
        ]
        if llm is None:
            llm = LLM()
        llm_res = llm.chat(messages, callback=None, model="gpt-4o-mini")
        logging.info(f"llm_res: {llm_res}")
        print(f"llm_res: {llm_res}")

        if llm_res.strip().lower() == "yes":
            primary_name, *duplicate_names = entity_names
            # Merge into a copy so the caller's Entity objects stay untouched.
            primary_entity = copy.deepcopy(entities[primary_name])
            for name in duplicate_names:
                entity = entities[name]
                primary_entity.references.extend(entity.references)
                primary_entity.properties.update(entity.properties)
                primary_entity.images.extend(entity.images)
                # record mapping of old name to new name
                original_to_merged_map[name] = primary_entity.name
                processed_entities.add(name)

            original_to_merged_map[primary_name] = primary_entity.name
            merged_entities.append(primary_entity)
            processed_entities.add(primary_name)
        else:
            # not merged: every entity keeps its own name
            for name in entity_names:
                original_to_merged_map[name] = name
                processed_entities.add(name)
                merged_entities.append(entities[name])

    # entities not in components
    for entity_name, entity in entities.items():
        if entity_name not in processed_entities:
            merged_entities.append(entity)
            original_to_merged_map[entity_name] = entity_name
    return merged_entities, original_to_merged_map
    

In [23]:
def update_relationships(relationships, merged_entities, original_to_merged_map):
    """
    Re-point relationships at the entities that survived merging.

    Each endpoint name is translated through ``original_to_merged_map``
    (names missing from the map are kept as they are) and then looked up
    among ``merged_entities`` by name. A new ``Relationship`` is created for
    every input relationship, carrying over its type, references, properties
    and images.

    Returns:
        The list of newly created relationships, in input order.
    """
    name_map = {entity.name: entity for entity in merged_entities}

    def _resolve(original_name):
        # Old name -> merged name -> entity (when known) -> final name.
        merged_name = original_to_merged_map.get(original_name, original_name)
        target = name_map.get(merged_name, merged_name)
        return target.name if isinstance(target, Entity) else target

    return [
        Relationship(
            start=_resolve(rel.start),
            end=_resolve(rel.end),
            type=rel.type,
            references=rel.references,
            properties=rel.properties,
            images=rel.images
        )
        for rel in relationships
    ]


In [8]:
from dotenv import load_dotenv
# Load environment variables from "../.env" (relative to the working directory).
load_dotenv("../.env")

# Sample entities used to exercise the merge pipeline below. The list
# includes near-duplicate names ("Company B" / "Company_B" and
# "NYC" / "New York City").
entities = [
    Entity(name="Company A", label="Company", references=["ref1"], properties={"location": "NY"}, images=["image1"]),
    Entity(name="Company B", label="Company", references=["ref2"], properties={"location": "SF"}, images=["image2"]),
    Entity(name="Company_B", label="Company", references=["ref3"], properties={"location": "SF"}, images=["image3"]),
    Entity(name="Company C", label="Company", references=["ref4"], properties={"location": "LA"}, images=["image4"]),
    Entity(name="NYC", label="City", references=["ref5"], properties={"population": "8M"}, images=["image5"]),
    Entity(name="New York City", label="City", references=["ref6"], properties={"population": "8M"}, images=["image6"]),
]

# Sample relationships; `start` and `end` refer to the entities above by name.
relationships = [
    Relationship(start="Company A", end="Company B", type="partnership", references=["rel_ref1"], properties={}, images=[]),
    Relationship(start="Company_B", end="Company C", type="acquisition", references=["rel_ref2"], properties={}, images=[]),
    Relationship(start="NYC", end="New York City", type="connection", references=["rel_ref3"], properties={}, images=[]),
]



In [9]:
# Build the KNN graph over the sample entities (default k=5, cosine distance
# between the BERT embeddings of the entity names).
graph = build_k_nearest_graph(entities)

In [35]:
# Inspect the neighbours recorded for "Company B", each paired with its cosine distance.
graph['Company B']

[(Entity(name='Company C', label='Company', references=['ref4'], properties={'location': 'LA'}, images=['image4']),
  0.038917303),
 (Entity(name='Company A', label='Company', references=['ref1'], properties={'location': 'NY'}, images=['image1']),
  0.20522702),
 (Entity(name='Company_B', label='Company', references=['ref3'], properties={'location': 'SF'}, images=['image3']),
  0.21149576),
 (Entity(name='NYC', label='City', references=['ref5'], properties={'population': '8M'}, images=['image5']),
  0.34885496)]

In [11]:
# Group entity names whose neighbour distance is below the default threshold
# (0.2) into connected components; these are the candidate merge groups.
components = filter_weakly_connected_components(graph)

In [12]:
# Show the candidate merge groups (one set of entity names per component).
components

[{'Company A', 'Company B', 'Company C'}]

In [30]:
# Ask the LLM, once per component, whether its entities should be merged.
merged_entities, original_to_merged_map = merge_entities(entities, components)

llm_res: no


In [31]:
# Show the entities after the merge step.
merged_entities

[Entity(name='Company B', label='Company', references=['ref2'], properties={'location': 'SF'}, images=['image2']),
 Entity(name='Company C', label='Company', references=['ref4'], properties={'location': 'LA'}, images=['image4']),
 Entity(name='Company A', label='Company', references=['ref1'], properties={'location': 'NY'}, images=['image1']),
 Entity(name='Company_B', label='Company', references=['ref3'], properties={'location': 'SF'}, images=['image3']),
 Entity(name='NYC', label='City', references=['ref5'], properties={'population': '8M'}, images=['image5']),
 Entity(name='New York City', label='City', references=['ref6'], properties={'population': '8M'}, images=['image6'])]

In [32]:
# Show which name each original entity name now maps to.
original_to_merged_map

{'Company B': 'Company B',
 'Company C': 'Company C',
 'Company A': 'Company A',
 'Company_B': 'Company_B',
 'NYC': 'NYC',
 'New York City': 'New York City'}

In [33]:
# Rewrite relationship endpoints using the old-name -> merged-name mapping.
updated_relationships = update_relationships(relationships, merged_entities, original_to_merged_map)

In [34]:
# Show the relationships after their endpoints were re-pointed.
updated_relationships

[Relationship(start='Company A', end='Company B', type='partnership', references=['rel_ref1'], properties={}, images=[]),
 Relationship(start='Company_B', end='Company C', type='acquisition', references=['rel_ref2'], properties={}, images=[]),
 Relationship(start='NYC', end='New York City', type='connection', references=['rel_ref3'], properties={}, images=[])]