In [None]:
import os
import ast
import json
import numpy as np
from utils.utils import TagTree, save_dict_to_file, load_dict_from_file
from utils.logger import setup_logger
import utils.vector_db as vector_db

import logging
import Levenshtein
import re
from tqdm import tqdm_notebook as tqdm
from collections import defaultdict, Counter

import nltk
from nltk.stem import WordNetLemmatizer
from nltk.corpus import stopwords, wordnet
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer
lemmatizer = WordNetLemmatizer()
import jieba
from openai import OpenAI


from pathlib import Path
from tqdm import tqdm_notebook as tqdm

from sentence_transformers import SentenceTransformer

from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_score
from sklearn.metrics.pairwise import cosine_similarity
import hashlib
import matplotlib.pyplot as plt

In [4]:
# Configuration. A key already exported in the environment wins; the placeholder is
# only a fallback, so running this cell never clobbers a real key.
os.environ.setdefault("OPENAI_API_KEY", "your_api_key")
DEFAULT_MODEL = "Meta-Llama-3-1-70B-Instruct"
DATA_PATH = "your_document_path"    # root directory of the knowledge-base documents
OUTPUT_PATH = "./output/test_data"  # summaries, tags, logs and indexes are written here

if not os.path.exists(DATA_PATH):
    # Fail loudly: exit(0) would shut down a Jupyter kernel (and report success
    # to a calling shell) instead of explaining what is wrong.
    raise FileNotFoundError(f"DATA_PATH does not exist: {DATA_PATH}")
os.makedirs(OUTPUT_PATH, exist_ok=True)

In [None]:
# Root-logger setup. basicConfig only configures the root logger when it has no
# handlers yet; every root handler (including one basicConfig may have just added)
# is then removed. NOTE(review): presumably this keeps records of the dedicated
# logger below from also being echoed by root handlers — depends on setup_logger.
logging.basicConfig(level=logging.INFO)
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)

# Dedicated logger for this notebook, writing to OUTPUT_PATH/entity.log.
# NOTE(review): overwrite=True presumably truncates the log on each run — confirm in setup_logger.
log_file_path = f'{OUTPUT_PATH}/entity.log'
logger = setup_logger('logger', log_file_path, overwrite=True)

In [None]:
from sentence_transformers import SentenceTransformer

# Multilingual sentence-embedding model, used for the summary vector index
# (vector_db.query / index_text) and for clustering entity tags (kmeans_clust).
# The device keeps its previous default "cuda:1" but can be overridden, e.g.
# EMBED_DEVICE=cuda:0 or EMBED_DEVICE=cpu, on hosts without a second GPU.
EMBED_DEVICE = os.environ.get("EMBED_DEVICE", "cuda:1")
embed_model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2', device=EMBED_DEVICE)

In [None]:
from functools import wraps
import time
import requests

# Client-side rate limiting for the model API.
def traffic_limit(max_qpm, max_qps):
    """Decorator factory throttling calls to ``max_qps`` per second and ``max_qpm`` per minute.

    Before each call, timestamps older than the 1 s / 60 s windows are discarded.
    If the per-second window is full, the call sleeps 1.1 s. If the per-minute
    window is full, it sleeps until the oldest call leaves the window, plus a 5 s margin.

    Args:
        max_qpm: maximum calls in any trailing 60-second window.
        max_qps: maximum calls in any trailing 1-second window.

    Returns:
        A decorator; the wrapped function keeps its metadata via ``functools.wraps``.

    Note: the call history is per decorated function and is not guarded by a lock,
    so calls from several threads are not coordinated.
    """
    from collections import deque

    def decorator(func):
        # Timestamps of recent calls, oldest first; deque gives O(1) popleft.
        qps_window = deque()
        qpm_window = deque()

        @wraps(func)
        def wrapper(*args, **kwargs):
            now = time.time()
            while qps_window and now - qps_window[0] > 1:
                qps_window.popleft()
            while qpm_window and now - qpm_window[0] > 60:
                qpm_window.popleft()

            # Check whether this call would exceed either limit.
            if len(qps_window) >= max_qps:
                print("waiting for QPS control")
                time.sleep(1.1)
            if len(qpm_window) >= max_qpm:
                # Clamp at zero: time.sleep raises ValueError for negative durations.
                sleep_time = max(0.0, qpm_window[0] + 60 - time.time() + 5)
                print(f"waiting for QPM Control: {sleep_time}s")
                time.sleep(sleep_time)

            called_at = time.time()
            qps_window.append(called_at)
            qpm_window.append(called_at)
            return func(*args, **kwargs)
        return wrapper
    return decorator

@traffic_limit(20, 1)
def request_model(prompt, context=None, model=DEFAULT_MODEL, temperature=0.1, max_tokens=30000):
    """Send ``prompt`` to the chat-completions endpoint and return the reply text.

    Args:
        prompt: user prompt, cut to its first ``max_tokens`` characters (a
            character cut, not a token count).
        context: optional running conversation (list of message dicts). If given,
            the system and user messages are appended to it in place, the whole
            list is sent, and the assistant reply is appended as well. The caller's
            list therefore accumulates the dialogue across calls.
        model: model name sent to the API.
        temperature: sampling temperature.
        max_tokens: also forwarded as the completion ``max_tokens`` limit.

    Returns:
        The reply serialized with ``json.dumps``: wrapped in double quotes, with
        newlines as the two characters ``\\n``. Any ``<think>...</think>`` span is
        removed. Returns None if the request fails.
    """
    prompt = prompt[:max_tokens]
    msgs = [{"role":"system",   "content": "The user may ask questions in Chinese or English. Please strictly follow the user’s required output format."},
        {"role": "user", "content": prompt}]
    client = OpenAI(api_key=os.environ['OPENAI_API_KEY'])

    if context is not None:
        context += msgs
        msgs = context

    # Non-standard sampling fields for the serving backend. The OpenAI SDK only
    # forwards extra request fields through ``extra_body``; an unknown keyword
    # argument raises TypeError before any request is sent.
    extend_fields = {"top_k": 1, "max_new_tokens": 1000}
    try:
        response = client.chat.completions.create(
            model=model,
            messages=msgs,
            stream=False,
            temperature=temperature,
            max_tokens=max_tokens,
            timeout=120,  # seconds
            top_p=0.8,
            extra_body=extend_fields)

        # Some OpenAI-compatible gateways add an ``error_code`` field; the standard
        # response object has none, so read it defensively.
        error_code = getattr(response, "error_code", None)
        if error_code is not None:
            logger.error(f'request fail: {error_code}')
            return None
        content = response.choices[0].message.content
    except Exception as err:
        logger.error(f'request fail: {err}')
        return None

    if context is not None:
        context.append({"role": "assistant", "content": content})
    # The escaped JSON form is what parse_model_resp expects downstream.
    resp_text = json.dumps(content, ensure_ascii=False)
    think_pattern = r'<think>.*?</think>'
    return re.sub(think_pattern, '', resp_text)

In [None]:
###util functions
def lsdir(dir_path):
    """Return the names in ``dir_path`` (in ``os.listdir`` order) that do not start with a dot."""
    visible = []
    for name in os.listdir(dir_path):
        if name.startswith('.'):
            continue
        visible.append(name)
    return visible

def count_file_num(KB_dir):
    """Count files under ``KB_dir`` recursively.

    Directories whose names start with a dot are not descended into. Dot-named
    *files* are still counted.
    """
    count = 0
    for _root, subdirs, files in os.walk(KB_dir):
        # Prune in place so os.walk skips hidden directories.
        subdirs[:] = [name for name in subdirs if not name.startswith('.')]
        count += len(files)
    return count

def check_key_existence(KB_dir, json_dir):
    """Print every file under ``KB_dir`` whose absolute path is not a key of the dict
    loaded from ``json_dir``.

    Dot-directories are not descended into. Returns None.
    """
    known = load_dict_from_file(json_dir)
    for root, subdirs, files in os.walk(KB_dir):
        subdirs[:] = [s for s in subdirs if not s.startswith('.')]
        for name in files:
            candidate = os.path.join(root, name)
            if os.path.abspath(candidate) in known.keys():
                continue
            print(candidate)

def find_best_match(a_element, b_list):
    """Return the element of ``b_list`` most similar to ``a_element``.

    Similarity is ``1 - levenshtein(a, b) / max(len(a), len(b))``. Ties go to the
    earliest candidate. Two empty strings count as identical (similarity 1.0)
    instead of dividing by zero. Returns None when ``b_list`` is empty.
    """
    def similarity(candidate):
        longest = max(len(a_element), len(candidate))
        if longest == 0:
            return 1.0
        return 1 - Levenshtein.distance(a_element, candidate) / longest

    best_match = None
    best_similarity = -1
    for candidate in b_list:
        score = similarity(candidate)
        if score > best_similarity:  # strict: keep the first of equal scores
            best_similarity, best_match = score, candidate
    return best_match

def _ensure_list_quotes(input_str):
    """Rewrite a bracketed list literal so that every element is quoted.

    Models often answer ``[A, 'B', "C"]``. Bare elements are wrapped in double
    quotes so the text can be parsed by ``ast.literal_eval``. Empty pieces (from
    ``[]`` or a trailing comma) are dropped, so ``"[]"`` stays ``"[]"``.
    """
    inner = input_str.strip()[1:-1]  # assumes the text is wrapped in [ ]
    quoted_elements = []
    for piece in inner.split(','):
        element = piece.strip()
        if not element:
            continue
        already_quoted = (element.startswith('"') and element.endswith('"')) or \
                         (element.startswith("'") and element.endswith("'"))
        quoted_elements.append(element if already_quoted else f'"{element}"')
    return '[' + ','.join(quoted_elements) + ']'


def parse_model_resp(resp_type):
    """Decorator factory that parses the ``###``-delimited payload of a model reply.

    The decorated function must return raw reply text (as produced by
    ``request_model``). The text between the first and second ``###`` is cleaned:
    literal ``\\n`` sequences, carriage returns and backslashes are removed, and
    full-width punctuation is mapped to ASCII. It is then converted by ``resp_type``:

    * ``"list"``         -> list, via ``ast.literal_eval`` after quoting bare elements
    * ``"str"``          -> string with all single and double quotes removed
    * ``"original_str"`` -> the cleaned string
    * ``"dict"``         -> whatever ``ast.literal_eval`` produces

    A payload longer than 3000 NLTK tokens yields ``""``. An unknown ``resp_type``
    yields None. The parsed value is logged at INFO level.

    Raises:
        ValueError: the reply is None or contains no ``###`` marker.
        SyntaxError / ValueError: a ``"list"`` payload cannot be parsed. For
            ``"dict"``, any exception from ``ast.literal_eval`` propagates.
    """
    def extract_formatted_data(text):
        if text is None:
            raise ValueError("model returned no response")
        sections = text.split('###')
        if len(sections) < 2:
            raise ValueError("model response contains no ### marker")

        formatted_text = sections[1].replace('\\n', '').replace('\r', '')
        formatted_text = formatted_text.replace("\\", "")
        formatted_text = formatted_text.replace("“", "\"").replace("”", "\"").replace("：",":").replace("，", ",").replace("；", ";").replace(";\"", "\"")
        # Guard against runaway outputs.
        if len(word_tokenize(formatted_text)) > 3000:
            return ""

        if resp_type == "list":
            try:
                return ast.literal_eval(_ensure_list_quotes(formatted_text))
            except (SyntaxError, ValueError) as e:
                logger.error("Error parsing model response: %s | %s", e, formatted_text)
                raise
        if resp_type == "str":
            return formatted_text.replace("\"", "").replace("\'", "")
        if resp_type == "original_str":
            return formatted_text
        if resp_type == "dict":
            try:
                return ast.literal_eval(formatted_text)
            except Exception as e:
                logger.error("Error parsing model response: %s | %s", e, formatted_text)
                raise
        return None

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            parsed_response = extract_formatted_data(func(*args, **kwargs))
            logger.info(parsed_response)
            return parsed_response
        return wrapper

    return decorator

@parse_model_resp(resp_type="list")
def sort_dir(dir_path, context=None):
    """Ask the model for a reading order of the visible entries of ``dir_path``.

    Returns the raw model reply. The ``parse_model_resp`` decorator turns its
    ``###``-delimited payload into a list. ``context`` is the running dialogue
    handed to ``request_model``, so a retry sees the earlier answers.
    """
    entries = lsdir(dir_path)
    prompt = f"""
    Currently, there is a document represented as a list: {entries}
    -----------------------------------
    If you need to read these documents in sequence to build a knowledge graph, and when processing subsequent documents, you can refer to the previous ones, what order do you think would best utilize the existing knowledge?
    Note: If there is only one document, return the original document as is. If there are multiple documents, you should prioritize reading summary or overview documents first, followed by documents that provide specific background knowledge, and finally documents that involve specific operational procedures.
    If there is historical dialogue context, you need to consider why the previous responses in the conversation were incorrect and then answer according to the requirements.
    -----------------------------------
    Provide your reasoning process and return the following format with priorities sorted from highest to lowest, marked by ###:
    Note only the list should be marked by ###, but not the reasoning process.
    ###
    [Returned priority list]
    ###
    Example output format:
    ###['A','B']###
    """
    logger.info(prompt)
    return request_model(prompt, context)


def refine_subnames(base, dir_subs):
    """Snap each model-produced name in ``dir_subs`` to the closest real entry of ``base``.

    Each name is replaced by its nearest match (``find_best_match``) among the
    entries of ``base`` whose names do not start with a dot. Hidden entries are
    excluded because the summarization loop validates the result against
    ``lsdir(base)``, which also skips them. A None or empty ``dir_subs`` gives [].
    """
    if not dir_subs:
        return []
    true_subnames = [name for name in os.listdir(base) if not name.startswith('.')]
    return [find_best_match(name, true_subnames) for name in dir_subs]

@parse_model_resp(resp_type="dict")
def test_format(s):
    """Parse ``s`` as if it were a model reply.

    The decorator extracts the ``###``-delimited payload and evaluates it as a
    literal (normally a dict).
    """
    raw_reply = s
    return raw_reply


In [None]:
@parse_model_resp(resp_type="original_str")
def summarize_no_ref(doc_path):
    """Summarize one document in a single English sentence, with no background context.

    The document is read as UTF-8 so that non-ASCII (e.g. Chinese) documents decode
    the same way regardless of the platform's locale encoding. Returns the raw model
    reply; the decorator extracts the ``###``-delimited summary.
    """
    title = os.path.basename(doc_path)
    with open(doc_path, "r", encoding="utf-8") as f:
        doc_content = f.read()

    prompt = f"""
        ----------------Task Requirements-------------------
        
        Summarize the content of the current input document in one sentence, no more than 100 words. Answer in English.
        
        ----------------Current Input Document-------------------
        
        Title: {title}
        
        Content: {doc_content}
        
        ---------------Output Format Requirements--------------------
        
        Return a string marked with ### at the beginning and end, and the string must not contain single or double quotes. The format is as follows:
        ###
        Summary content
        ###
    """
    logger.info(prompt)
    return request_model(prompt)


# doc_path: path of the document to summarize.
# Background summaries are retrieved from the temporary vector index at vecDB_path.
def summarize_with_ref(doc_path, vecDB_path, max_context=20):
    """Summarize one document, using previously indexed summaries as background.

    Args:
        doc_path: path of the document to summarize, read as UTF-8 so decoding
            does not depend on the platform's locale encoding.
        vecDB_path: vector index of earlier summaries. If the file does not exist
            yet, no background is used.
        max_context: number of summaries requested from ``vector_db.query``.

    Returns:
        The ``request_model`` reply (a JSON-quoted string), or None on failure.
    """
    title = os.path.basename(doc_path)
    with open(doc_path, "r", encoding="utf-8") as f:
        doc_content = f.read()

    # Query the temporary vector DB on every call; it grows as documents are summarized.
    if os.path.exists(vecDB_path):
        ref_sums = vector_db.query(doc_content, max_context, embed_model, vecDB_path)
    else:
        ref_sums = []

    ref_context = "".join(f"{i}. {ref} \n" for i, ref in enumerate(ref_sums, start=1))

    prompt = f"""
    ----------------Task Requirements-------------------
    
    Briefly summarize the content of the current input document, answering in the same language as the input text.
    
    ---------------Output Format Requirements--------------------
    
    Return a string that must not contain single or double quotes. The format is as follows:

    ----------------Current Input Document-------------------
    
    Title: {title}
    
    Content: {doc_content}
    
    ---------------Background Knowledge-------------------
    The input document is located in a subdirectory of the following document, and all entities mentioned within the document are within the context of the background knowledge. The background knowledge may include domain introductions, terminology definitions, operational methods, etc. Please consider this background knowledge when summarizing the input document. Note that you should summarize the current input document, referencing the background knowledge, rather than treating the background knowledge as the main document.
    Background knowledge:
    {ref_context}
    """
    logger.info(prompt)
    return request_model(prompt)


# Inputs: knowledge-base root directory, output JSON file for the summaries (the
# summary vector index is stored next to it), and background summaries per document.
def summarize_KB_docs(KB_dir, save_dir, max_context=20, filetype=".md"):
    """Summarize every ``filetype`` document under ``KB_dir`` in a model-chosen order.

    In each directory, the visible entries are ordered by ``sort_dir``. Its prompt
    asks for overview documents first. If the ordering is not a permutation of
    the real entries, the call is retried up to 5 times, snapping names with
    ``refine_subnames``. If it still fails, the directory is skipped. Documents are
    summarized with ``summarize_with_ref``, and each summary is indexed into a
    fresh vector index so later documents can use it as background.
    Sub-directories are handled recursively at their position in the order.

    Args:
        KB_dir: root directory of the knowledge base.
        save_dir: JSON file path for ``{abs_path: summary}``. The vector index
            ``summary_faiss_index.bin`` is recreated in the same directory.
        max_context: number of background summaries retrieved per document.
        filetype: suffix of the files to summarize (default ``".md"``).

    Returns:
        dict mapping absolute document path -> summary.
    """
    accessed_files = 0
    # Progress denominator: only files the traversal below will actually summarize.
    total_files = 0
    for _dirpath, dirnames, filenames in os.walk(KB_dir):
        dirnames[:] = [d for d in dirnames if not d.startswith('.')]
        total_files += sum(1 for name in filenames
                           if not name.startswith('.') and name.endswith(filetype))

    docs_summary = {}
    vecDB_path = os.path.join(os.path.dirname(save_dir), "summary_faiss_index.bin")
    if os.path.exists(vecDB_path):
        os.remove(vecDB_path)  # start every run from an empty index
    vecDB_dir = os.path.dirname(vecDB_path)
    if vecDB_dir:  # "" means the current directory, which needs no creation
        os.makedirs(vecDB_dir, exist_ok=True)

    def same_entries(candidate, expected):
        """True if ``candidate`` is a list holding exactly the names of ``expected``, in any order."""
        if not isinstance(candidate, list):
            return False
        try:
            return Counter(candidate) == Counter(expected)
        except TypeError:  # unhashable items (e.g. nested lists) are never valid names
            return False

    def request_order(base_dir, sort_context):
        """Call sort_dir, treating an unparsable reply as invalid (None) so it can be retried."""
        try:
            return sort_dir(base_dir, sort_context)
        except Exception as err:
            logger.error(f"sort_dir failed for {base_dir}: {err}")
            return None

    def do_summarize(base_dir, filetype=filetype):
        nonlocal accessed_files
        entries = lsdir(base_dir)
        if entries == []:
            return
        logger.info("\n-------------------")
        logger.info(f"base dir: {base_dir}")
        sort_context = []
        dir_subs = request_order(base_dir, sort_context)

        # Make sure the ordering is a permutation of the real entries.
        max_retry = 5
        while not same_entries(dir_subs, entries) and max_retry > 0:
            logger.error(f"invalid sorted directory: {dir_subs}, retry...")
            dir_subs = request_order(base_dir, sort_context)
            logger.info(f"newly sorted dir: {dir_subs}")
            if isinstance(dir_subs, list):
                dir_subs = refine_subnames(base_dir, dir_subs)
            logger.info(f"refined sorted dir: {dir_subs}")
            max_retry -= 1

        if not same_entries(dir_subs, entries):
            logger.error(f"sorting failed: {base_dir}")
            return

        logger.info(f"final sorted dir: {dir_subs}")
        for sub in dir_subs:
            abs_path = os.path.abspath(os.path.join(base_dir, sub))
            print(abs_path)
            if not os.path.exists(abs_path):
                logger.error(f"非法路径：{abs_path}")  # "invalid path"
                continue

            if os.path.isfile(abs_path) and abs_path.endswith(filetype):
                try:
                    logger.info(abs_path)
                    logger.info(f"ref_docs: { [os.path.basename(docname) for docname in docs_summary.keys()] }")
                    single_doc_summary = summarize_with_ref(abs_path, vecDB_path, max_context)
                    print(single_doc_summary)
                    if single_doc_summary is None:
                        # A failed request must not be stored or indexed as background.
                        logger.error(f"error processing {abs_path}: empty summary")
                        print(f"error processing {abs_path}")
                        continue
                    docs_summary[abs_path] = single_doc_summary
                    vector_db.index_text(single_doc_summary, f"title：{os.path.basename(abs_path)}， content : {single_doc_summary}", embed_model, vecDB_path)
                    accessed_files += 1
                    print(f"{accessed_files}/{total_files}")
                except Exception as err:
                    logger.error(f"error processing {abs_path}: {err}")
                    print(f"error processing {abs_path}")
                    continue
            elif os.path.isdir(abs_path):
                do_summarize(abs_path, filetype)

    do_summarize(KB_dir)
    save_dict_to_file(docs_summary, save_dir)
    # The vector index is kept on disk: later steps query it for background context.
    print("done")
    return docs_summary
    

# Summarize the whole knowledge base, retrieving up to 20 background summaries per
# document. Writes OUTPUT_PATH/docs_summary.json and the summary vector index next to it.
summarize_KB_docs(DATA_PATH, f"{OUTPUT_PATH}/docs_summary.json", 20) 

In [None]:
# tags_with_weight: {tag: weight} when weighted=True, otherwise a list [tag, ...].
def kmeans_clust(tags_with_weight, weighted=False, visualize=False, max_clusters=100):
    """Cluster tag embeddings with KMeans, choosing k by silhouette score.

    Candidate k values are ``2 .. min(max_clusters, number of distinct tags) - 2``.

    Args:
        tags_with_weight: ``{tag: weight}`` when ``weighted`` is True (weights are
            passed as KMeans ``sample_weight``), otherwise a list of tags.
        weighted: see above.
        visualize: if True, plot a 2-D PCA projection coloured by cluster.
        max_clusters: upper bound for the candidate number of clusters.

    Returns:
        A list of dicts, one per cluster, with keys:

        * ``'vectors'``: member embeddings
        * ``'all_tags'``: member tags
        * ``'center_vector'``: the cluster centroid
        * ``'center_tag'``: the tag whose embedding is most cosine-similar to the centroid

        Empty input gives ``[]``. With too few distinct tags to score any k,
        all tags are returned as a single cluster.
    """
    if weighted:
        tags = list(tags_with_weight.keys())
        weights = list(tags_with_weight.values())
    else:
        tags = list(tags_with_weight)
        weights = None

    if not tags:
        return []

    embeddings = embed_model.encode(tags)

    def nearest_tag(vector):
        """Tag whose embedding is most cosine-similar to ``vector``."""
        sims = cosine_similarity([vector], embeddings)[0]
        return tags[int(np.argmax(sims))]

    upper = min(max_clusters, len(set(tags)))
    candidate_ks = list(range(2, upper - 1))
    silhouette_scores = []
    for n_clusters in candidate_ks:
        trial = KMeans(n_clusters=n_clusters, random_state=42)
        trial.fit(embeddings, sample_weight=weights)
        if len(set(trial.labels_)) > 1:
            silhouette_scores.append(silhouette_score(embeddings, trial.labels_))
        else:
            silhouette_scores.append(-1)

    if silhouette_scores:
        best_idx = int(np.argmax(silhouette_scores))
        best_n_clusters = candidate_ks[best_idx]
        print(f"best cluster number: {best_n_clusters} (Silhouette Coefficient: {silhouette_scores[best_idx]:.3f})")
    else:
        # No k could be scored (too few distinct tags, or max_clusters < 4).
        best_n_clusters = 1
        print("too few distinct tags for silhouette selection, using 1 cluster")

    kmeans_best = KMeans(n_clusters=best_n_clusters, random_state=42)
    y_kmeans = kmeans_best.fit_predict(embeddings, sample_weight=weights)

    clusters_info = []
    for i in range(best_n_clusters):
        indices = np.where(kmeans_best.labels_ == i)[0]
        center_vector = kmeans_best.cluster_centers_[i]
        clusters_info.append({
            'vectors': embeddings[indices],
            # Members are taken by index; no need to re-derive them from their embeddings.
            "all_tags": [tags[j] for j in indices],
            'center_vector': center_vector,
            "center_tag": nearest_tag(center_vector),
        })

    if visualize and len(tags) >= 2:  # PCA(n_components=2) needs at least 2 samples
        pca = PCA(n_components=2)
        embeddings_pca = pca.fit_transform(embeddings)

        plt.figure(figsize=(10, 6))
        scatter = plt.scatter(embeddings_pca[:, 0], embeddings_pca[:, 1],
                              c=y_kmeans, s=100, cmap='viridis', alpha=0.7, label='Cluster Labels')
        cbar = plt.colorbar(scatter)
        cbar.set_label('Cluster Label')

        plt.title('KMeans Clustering Visualization of Embeddings')
        plt.xlabel('Principal Component 1')
        plt.ylabel('Principal Component 2')
        plt.legend()
        plt.grid()
        plt.show()

    return clusters_info

def filter_tags(input_list, min_times=2):
    """Deduplicate ``input_list``, keeping items that occur at least ``min_times`` times.

    Items keep their order of first appearance. Prints how many survive.
    """
    # Merge identical elements by counting them.
    frequencies = Counter(input_list)
    kept = [tag for tag, freq in frequencies.items() if freq >= min_times]
    print(f"tags number:{len(kept)}")
    return kept

def split_tags(merged_list, group_size=10):
    """Split ``merged_list`` into consecutive slices of ``group_size`` items.

    The last slice may be shorter.
    """
    groups = []
    for start in range(0, len(merged_list), group_size):
        groups.append(merged_list[start:start + group_size])
    return groups
    
# docs_sum: {doc_path: summary, ...}
# Returns: [entity_tag, ...]
def extract_entity_tags(doc_path, docs_sum):
    """Ask the model which entity *types* occur in one document.

    Args:
        doc_path: document path. Its absolute form is looked up in ``docs_sum``.
            The file is read as UTF-8.
        docs_sum: ``{abs_doc_path: summary}`` mapping.

    Returns:
        Unique entity-type names in order of first appearance, or ``[]`` after 10
        failed attempts. When a reply cannot be parsed as a list, the model is
        asked in the same conversation to correct its previous answer.

    Raises:
        KeyError: the document has no entry in ``docs_sum``.
    """
    doc_path = os.path.abspath(doc_path)
    doc_summary = docs_sum[doc_path]
    with open(doc_path, "r", encoding="utf-8") as f:
        doc_content = f.read()
    doc_title = os.path.basename(doc_path)
    context = []

    @parse_model_resp(resp_type="list")
    def parse_list(prompt, context):
        return request_model(prompt, context)

    entity_extraction_prompt = f"""
        ----------------Task Requirements-------------------
        Next, a text and its summary will be provided. You need to refer to the summary and the content of the text to classify meaningful entities within the text. Note that you only need to return the type names of the summary, not the specific entity names, and the number of entity types should be kept to a minimum. Each entity type should be in singular form and begin with an uppercase letter.
        
        ---------------Output Format Requirements--------------------
        Return a list containing all the class names in the following format, with the list being wrapped in ### markers and each type being a string marked with single quotes, formatted as follows:
        ###
        [Return type list]
        ###
        ---------------Current Input--------------------
        Current input text summary:
        《{doc_title}》 : {doc_summary}
        
        Current input text content:
        {doc_content}
        """

    logger.info(entity_extraction_prompt)

    prompt = entity_extraction_prompt
    for _attempt in range(10):
        try:
            entity_types = parse_list(prompt, context)
            # Explicit check (an assert would vanish under `python -O`).
            if not isinstance(entity_types, list):
                raise TypeError(f"expected a list, got {type(entity_types).__name__}")
            # Deduplicate while keeping a deterministic order.
            return list(dict.fromkeys(entity_types))
        except Exception as err:
            err_msg = f"err: {err}"
            logger.error(err_msg)
            format_error_prompt = f"""
                The returned object does not meet the format requirements, an issue was encountered while parsing the return content: {err_msg}.
                
                Please check and correct the previous return result. The requirements are as follows:
                
                The return format must be a list object marked with ### at both the beginning and end, where each element is a string marked with single quotes, and no quotes are allowed inside the element strings.
                The entity types should be meaningful and concise.
                Continue to modify the format of the returned object so that it satisfies the above format requirements, i.e., a list containing all the class names marked with ### at both the beginning and end.
                
                Note that you only need to provide the corrected result, without the process or code.               
                """
            prompt = format_error_prompt
            print(f"error: {err_msg}, retry")

    print("exceed max retry, return []")
    return []
            

# Returns: {doc_path: [entity_type, ...], ...}
def extract_KB_entity_tags(KB_dir, docs_sum):
    """Run ``extract_entity_tags`` over every summarized file under ``KB_dir``.

    Dot-prefixed files and directories are skipped, like in the summarization step,
    so they never have a summary. Files without an entry in ``docs_sum`` (e.g.
    non-.md files, or documents whose summarization failed) are logged and skipped
    instead of raising a KeyError each time. A failure on one document is logged
    and does not stop the loop.

    Returns:
        dict mapping absolute document path -> list of entity types.
    """
    KB_schema = {}
    doc_paths = []
    for root, dirs, files in os.walk(KB_dir):
        dirs[:] = [d for d in dirs if not d.startswith('.')]
        doc_paths.extend(os.path.abspath(os.path.join(root, name))
                         for name in files if not name.startswith('.'))

    summarized = []
    for doc_path in doc_paths:
        if doc_path in docs_sum:
            summarized.append(doc_path)
        else:
            logger.info(f"no summary for {doc_path}, skipped")

    total_files = len(summarized)
    accessed_files = 0
    for doc_path in summarized:
        print(doc_path)
        try:
            KB_schema[doc_path] = extract_entity_tags(doc_path, docs_sum)
            accessed_files += 1
            print(f"{accessed_files}/{total_files}")
        except Exception as err:
            logger.error(f"error processing {doc_path}: {err}")
            print(err)
    return KB_schema

# tag_schema: {doc_path: [tag, ...], ...}
# Returns: [{tag: definition, ...}, ...]
def comment_KB_entity_tag(tag_schema, ref_summary_path, ref_num=20, min_times=5):
    """Merge and define entity types across the knowledge base.

    All tag lists in ``tag_schema`` are flattened. Tags occurring fewer than
    ``min_times`` times are dropped, and the rest are clustered with
    ``kmeans_clust``. Each cluster is split into groups of 3. For every group,
    ``ref_num`` related summaries are retrieved from the vector index at
    ``ref_summary_path``. The model is asked to merge and define the types, then
    to refine its answer in the same dialogue.

    Returns:
        list of the refined ``{type: definition}`` results, one per processed group.
        A group whose request or parsing fails is logged and skipped.
    """
    @parse_model_resp(resp_type="str")
    def ask_comment(prompt):
        return request_model(prompt)

    @parse_model_resp(resp_type="dict")
    def refine_tags(context):
        refine_prompt = """
            Check and correct the previous return result. The requirements are as follows:
            
            Pay attention to the background knowledge and consider the meanings represented by various types in the context of the provided background knowledge to improve the specificity of the descriptions and avoid being overly broad.
            Merge entities of the same meaning or with a clear inclusion relationship. For entities A and B that have a clear inclusion relationship, where A includes B, merge them into the broader type A. For example, "server" and "http server" should be merged into "server."
            The return format must be a JSON object marked with ### at the beginning and end, where both keys and values are enclosed in double quotes. Each key is a type name, and the value is the definition of that type. However, there should be no extra quotes inside the key and value strings, or around the JSON object itself.
            """
        return request_model(refine_prompt, context)

    full_entity_schema = []
    tags = [tag for doc in tag_schema for tag in tag_schema[doc]]
    tags = filter_tags(tags, min_times)
    print(tags)
    if not tags:
        # Nothing occurs often enough; an empty tag list cannot be clustered.
        return full_entity_schema

    clusters_info = kmeans_clust(tags, False, False)
    for processed_clusters, cluster in enumerate(clusters_info):
        print(f"{processed_clusters}/{len(clusters_info)}")
        cluster_tag_groups = split_tags(cluster["all_tags"], 3)
        print(cluster_tag_groups)
        for cluster_tags in cluster_tag_groups:
            ref_sums = vector_db.query(str(cluster_tags), ref_num, embed_model, ref_summary_path)
            ref_context = "".join(f"{i}. {ref} \n" for i, ref in enumerate(ref_sums, start=1))

            extraction_prompt = f"""
                ---------------------------------Task Requirements----------------------------------
                There is a series of entity types. First, correct any spelling errors, then replace types that have the same meaning or have a clear inclusion relationship with one common term, and briefly define each merged type in one sentence.
                When defining, make sure to incorporate background knowledge as much as possible to improve the specificity of the descriptions and avoid being overly broad.
                For entity types A and B that have a clear inclusion relationship where A includes B, merge them into the broader type A. For example, "server" and "http server" should be merged into "server."
                The input entity types are as follows: {cluster_tags}
                
                -----------------------------------Output Format Requirements------------------------------------
                
                Return a JSON object containing all merged class definitions in the following format, where both the key and the value are marked with single quotes. Special attention is needed: the key and value strings must not contain any quotes internally. The JSON object should be marked with ### at the beginning and end, and should follow the format below, but keep in mind to only reference the format of the example, and do not treat the example as input:
                ###
                {{"entity_type": "description and definition", }}
                ###

                -----------------------------------Example input------------------------------------
                ['Receiver', 'Header', 'Message', 'Series', 'Protocol'] 

                -----------------------------------Example Output------------------------------------
                ###{{'Receiver': 'A component in the Prometheus ecosystem that accepts and processes incoming metrics or alerts, often part of a monitoring or alerting system.', 'Header': 'A metadata section in a data packet or message that contains control information, used to manage the transmission and processing of metrics or alerts in Prometheus.', 'Message': 'A unit of communication in Prometheus that contains metrics, alerts, or other data, transmitted between components such as exporters, servers, and AlertManager.', 'Series': 'A sequence of timestamped data points in Prometheus, representing a time series used for monitoring and analysis, often associated with a metric and labeled dimensions.', 'Protocol': 'A set of rules and conventions in Prometheus that govern the format and transmission of data, such as the exposition format or remote-write protocol, ensuring reliable and efficient communication between components.'}}###
                
                -----------------------------------Background Knowledge-----------------------------------
                
                These entity types are part of the same system, for which there is documentation that may include domain introductions, terminology definitions, operational methods, and other content. Please merge and briefly define and describe each type using all available documentation:
                {ref_context}
            """

            try:
                origin_resp = ask_comment(extraction_prompt)
                context = [
                    {"role": "user", "content": extraction_prompt},
                    {"role": "assistant", "content": origin_resp},
                ]
                refined_resp = refine_tags(context)
                full_entity_schema.append(refined_resp)
                print(refined_resp)
            except Exception as err:
                # Skip only this group, so one bad reply does not drop the
                # remaining groups of the cluster.
                logger.error(f"error defining tags {cluster_tags}: {err}")
                print(err)
                continue
    return full_entity_schema

In [None]:
# Extract entity types for every summarized document and save them per document.
# The last expression displays how many documents were saved.
# NOTE(review): the purpose of the initial 3 s pause is not evident from this cell.
time.sleep(3)
docs_sum = load_dict_from_file(f"{OUTPUT_PATH}/docs_summary.json")
entity_tags = extract_KB_entity_tags(DATA_PATH, docs_sum)
save_dict_to_file(entity_tags, f"{OUTPUT_PATH}/entity_tags_bydoc.json")
len(load_dict_from_file(f"{OUTPUT_PATH}/entity_tags_bydoc.json"))

In [None]:
# Cluster, merge and define the entity types. Only tags seen at least 5 times are kept.
# Background comes from the summary vector index written by summarize_KB_docs next
# to docs_summary.json.
entity_tags = load_dict_from_file(f"{OUTPUT_PATH}/entity_tags_bydoc.json")
full_entity_schema = comment_KB_entity_tag(entity_tags, f"{OUTPUT_PATH}/summary_faiss_index.bin", min_times=5)
save_dict_to_file(full_entity_schema, f"{OUTPUT_PATH}/full_entity_schema.json")

In [None]:
###entity extraction


# return: {tag: [entities]}
def extract_entity_full_schema(entity_schema, doc_path, docs_sum):
    """Extract schema-typed entities from one document with the LLM.

    The model is first asked to extract entities of the types described in
    ``entity_schema``, then to check and correct that answer. The corrected reply
    is parsed with ``ast.literal_eval``; on a parse error the model is re-prompted
    to fix the format (up to 10 attempts). Each entity is kept only if it occurs
    in the document text, verbatim or lower-cased.

    Args:
        entity_schema: entity type descriptions, interpolated verbatim into the prompt.
        doc_path: path of the document to read.
        docs_sum: mapping from absolute document path to the document summary.

    Returns:
        dict mapping entity type -> list of entity strings found in the document,
        or ``{}`` when the model output could not be parsed within the retry budget.
    """
    with open(doc_path, "r") as f:
        doc_content = f.read().replace('"', " ").replace('“', " ").replace('”', " ")
    # Drop blank lines to keep the prompt compact.
    doc_content = "\n".join(line for line in doc_content.splitlines() if line.strip() != "")

    # docs_sum is looked up by absolute path.
    doc_path = os.path.abspath(doc_path)
    doc_summary = docs_sum[doc_path].replace('"', " ").replace('“', " ").replace('”', " ")

    # parse_model_resp post-processes the raw model reply; its exact behavior
    # for resp_type="original_str" is defined where parse_model_resp lives.
    @parse_model_resp(resp_type="original_str")
    def refine_entities_1(prompt, context):
        return request_model(prompt, context)

    extraction_prompt = f"""
        Here is a document. Please extract meaningful entities from the document based on their types in conjunction with the specific content, and explain the reasons.
        
        Requirements:
        
        1.Pay attention to reading the document summary and consider how the summary encapsulates the entities within the document.        
        2.Entities should be as specific as possible, avoiding being too abstract or broad. And sufficiently rich in number.
        3.Ensure the extracted entities do exist in the provided text(Detailed Content) but not Entity Type Description. 
        4.Extract entity All entities should be in singular form and capitalized.
        5.Return in JSON format, where both keys and values are enclosed in double quotes. Each key is a type name, and the value contains entities of that type, separated by semicolons (;). However, double quotes are not allowed within the key and value strings.
        6.Ensure that the extracted entities are consistent with the original text. If the original text is Chinese, the entities are expressed in Chinese. If the original text is in English, the entity is in English.
        
        Enclose the returned JSON object with ### markers at the beginning and end, following the format below. Note that you should only refer to the example format and not consider the example as input:
        
           ###
           {{"EntityType1": "Entity1;Entity2;",  }}
           ###

        Document Summary:
        {{
        {doc_summary}
        }}  
        
        Detailed Content:
        {{
        {doc_content}
        }}  
        
        --------Entity Type Description-------------
        The following are all possible types of entities and their descriptions. Please only extract entities that can be categorized into the following types. If there are none, return an empty dictionary. Note that you can only extract entities from the above document, not from the entity type descriptions.
        {{
        {entity_schema}
        }}  
    """

    context = []

    # The first reply is not parsed here. The same `context` list is passed to the
    # refine step, presumably because request_model records the exchange in it
    # (TODO confirm against request_model).
    request_model(extraction_prompt, context)

    refine_prompt_1 = """
        Check and correct the previous classification results. The requirements are as follows:

        1.Pay attention to reading the document summary and consider how the summary encapsulates the entities in the document.
        2.Ensure the extracted entities do exist in the provided text but not from Entity Type Description. Delete the unexisted ones.
        3.Check if there are any missing entities. If so, add the missing entities to the returned JSON object.
        4.All entities should be in singular form and capitalized.
        5.The return format must be a JSON object marked by ### at the beginning and end, with both keys and values marked with double quotes. Each key is a type name, and the value consists of entities under that type, separated by semicolons (;) within the same type and commas (,) between different types. Note that double quotes and parentheses are not allowed within key or value strings. 
        6.Ensure that the extracted entities are consistent with the original text. If the original text is Chinese, the entities are expressed in Chinese. If the original text is in English, the entity is in English.
        """

    refined_resp1 = refine_entities_1(refine_prompt_1, context)
    print(refined_resp1)

    format_refine_context = []
    max_try = 10
    while max_try > 0:
        try:
            doc_entities = ast.literal_eval(refined_resp1)
            # Rebuilt on every attempt so a parse that fails halfway leaves no partial entries.
            res = {}
            for tag, entities in doc_entities.items():
                if entities.strip() == "":
                    continue
                # Accept full-width semicolons too; strip padding around each entity.
                candidates = [item.strip() for item in entities.replace("；", ";").split(";")]
                # Keep non-empty entities that occur in the document verbatim or lower-cased
                # (the lower-cased check helps with English text). Without the emptiness
                # check, the "" produced by a trailing ';' would always match.
                entities_list = [
                    item for item in candidates
                    if item and (item in doc_content or item.lower() in doc_content)
                ]
                if not entities_list:
                    continue
                print(f"{tag}: {entities_list}")
                res[tag] = entities_list
            return res
        except Exception as err:
            err_msg = f"err: {err}"
            logger.error(err_msg)
            format_error_prompt = f"""
                The following object is parsed from the string: {refined_resp1}
                The parsing reveals that the object does not meet the formatting requirements. The return format must be a JSON object marked with ### at the beginning and end, where both keys and values are enclosed in double quotes. Each key is a type name, and the value represents the entities under that type, with different entities separated by semicolons (;) and different types separated by commas (,). However, double quotes and parentheses are not allowed within the key and value strings.
                Continue to modify the format of the object to meet the aforementioned formatting requirements.
                Note that you only need to provide the corrected result, without any process or code.    
                """
            refined_resp1 = refine_entities_1(format_error_prompt, format_refine_context)
            print(f"error: {err_msg}, retry")
            max_try -= 1

    print("exceed max retry, return {}")
    return {}

# entity_schema: [{tag: comment}, ...]
# return: {doc_path: {tag: [entity, ...]}, ...}
def extract_KB_entities(entity_schema_path, docs_sum_path):
    """Run ``extract_entity_full_schema`` over every summarized document.

    Args:
        entity_schema_path: file holding the entity schema (read with load_dict_from_file).
        docs_sum_path: file mapping document path -> summary; its keys decide
            which documents are processed.

    Returns:
        dict mapping each document path to its ``{tag: [entity, ...]}`` result.
        A document whose extraction raised is kept with ``{}``.
    """
    docs_sum = load_dict_from_file(docs_sum_path)
    entity_schema = load_dict_from_file(entity_schema_path)

    doc_num = len(docs_sum)
    all_entities = {}

    # enumerate advances the progress index for failed documents too, so every
    # document gets a distinct "i/N" label.
    for processed_doc, doc_path in enumerate(docs_sum.keys()):
        all_entities[doc_path] = {}
        try:
            all_entities[doc_path] = extract_entity_full_schema(entity_schema, doc_path, docs_sum)
            print(f"{processed_doc+1}/{doc_num}: {doc_path} succeed")
        except Exception as err:
            # Best effort: report and move on to the next document.
            print(f"{processed_doc+1}/{doc_num}: {err} processing {doc_path}")
    return all_entities
    


In [None]:
# Stage 3: extract typed entities for every summarized document.
all_entities_path = f"{OUTPUT_PATH}/all_entities.json"
all_entities = extract_KB_entities(
    f"{OUTPUT_PATH}/full_entity_schema.json",  # entity_schema_path
    f"{OUTPUT_PATH}/docs_summary.json",        # docs_sum_path
)
save_dict_to_file(all_entities, all_entities_path)
# Round-trip check: number of document entries written.
len(load_dict_from_file(all_entities_path))

In [None]:
###triple extraction

def split_text(content, chunk_size, lang="eng", tokenizer=None):
    """Split text into chunks of at most ``chunk_size`` tokens, cutting at periods.

    The text is cut at every ASCII period ``.`` (a Chinese full stop ``。`` does
    not cut). The resulting non-empty sentences are packed greedily, in order,
    into chunks whose token count stays within ``chunk_size``. A single sentence
    longer than ``chunk_size`` becomes a chunk of its own. Sentences inside a
    chunk are joined with newlines; the periods are not restored.

    Args:
        content: text to split.
        chunk_size: maximum number of tokens per chunk.
        lang: ``"eng"`` (NLTK ``word_tokenize``) or ``"chn"`` (``jieba.lcut``).
            Ignored when ``tokenizer`` is given.
        tokenizer: optional callable ``str -> sequence of tokens`` that overrides ``lang``.

    Returns:
        list of chunk strings; empty when ``content`` contains no non-empty sentence.

    Raises:
        ValueError: if ``lang`` is not supported and no ``tokenizer`` is given.
            (Previously the function returned None, which crashed callers later.)
    """
    if tokenizer is None:
        if lang == "eng":
            tokenizer = word_tokenize
        elif lang == "chn":
            tokenizer = jieba.lcut
        else:
            raise ValueError(f"unsupported lang {lang!r}; expected 'eng' or 'chn'")

    sentences = [sentence.strip() for sentence in content.strip().split(".")]

    output_chunks = []
    current_chunk = []
    current_tokens = 0
    for sentence in sentences:
        if not sentence:
            continue
        token_count = len(tokenizer(sentence))
        if current_tokens + token_count <= chunk_size:
            current_chunk.append(sentence)
            current_tokens += token_count
        else:
            # Budget exceeded: close the current chunk and start a new one with this sentence.
            if current_chunk:
                output_chunks.append("\n".join(current_chunk))
            current_chunk = [sentence]
            current_tokens = token_count

    if current_chunk:
        output_chunks.append("\n".join(current_chunk))
    return output_chunks

# entity_list: {tag: [entities], ...}
# return: set of entities, or of (entity, tag, description) tuples when with_tag=True
def filter_entity_by_text(entity_list, text, with_tag=False, entity_schema=None):
    """Return the entities from ``entity_list`` that occur in ``text``.

    An entity matches when it appears in ``text`` verbatim or lower-cased.
    Empty or whitespace-only entities are ignored: an empty string is a
    substring of every text and would otherwise always match.

    Args:
        entity_list: mapping of entity tag -> list of entity strings.
        text: text to search.
        with_tag: if True, return ``(entity, tag, description)`` tuples instead of
            bare entity strings. Tags missing from ``entity_schema`` are skipped.
        entity_schema: list of ``{tag: description}`` dicts, used only with
            ``with_tag``. Defaults to no schema.

    Returns:
        set of matching entities (or tuples, see ``with_tag``).
    """
    # Flatten [{tag: description}, ...] into one lookup dict.
    descriptions = {tag: desc for part in (entity_schema or ()) for tag, desc in part.items()}

    result = set()
    for tag, candidates in entity_list.items():
        found = {
            entity
            for entity in candidates
            if entity.strip() and (entity in text or entity.lower() in text)
        }
        if not found:
            continue
        if not with_tag:
            result.update(found)
        elif tag in descriptions:
            result.update((entity, tag, descriptions[tag]) for entity in found)
    return result


# entity_schema: [{entity_type: ...}, ...]  (only the keys are used here)
# entity_list: {tag: [entities], ...}
# return: [("SubjectType: Subject", "Predicate", "ObjectType: Object"), ...]
def extract_triple(entity_schema, entity_list, doc_path, chunk_size=200, lang="chn"):
    """Extract relation triples from one document, chunk by chunk, with the LLM.

    For each chunk produced by ``split_text``, the entities of ``entity_list``
    that occur in it are collected, and chunks with none are skipped. The model
    is then asked in three steps:

    1. for relations between those entities,
    2. to annotate the entity types from the schema,
    3. to normalize the output into a list.

    The reply (post-processed by ``parse_model_resp``) is parsed with
    ``ast.literal_eval``, with up to 3 attempts; the model is re-prompted with
    the parse error between attempts.

    Args:
        entity_schema: list of dicts whose keys are the entity types offered to the model.
        entity_list: mapping tag -> entity strings found in this document.
        doc_path: path of the document to read.
        chunk_size: token budget per chunk, forwarded to ``split_text``.
        lang: tokenizer language forwarded to ``split_text`` ("eng" or "chn").

    Returns:
        list of all parsed items across chunks (expected to be 3-tuples).
    """
    # Sorted so the tag list rendered into the prompt is the same on every run
    # (the iteration order of a set of strings changes between interpreter sessions).
    entity_tags = sorted(set(key for item in entity_schema for key in item.keys()))
    with open(doc_path, "r") as f:
        doc_content = f.read().replace('"', " ").replace('“', " ").replace('”', " ")

    doc_path = os.path.abspath(doc_path)  # only used in log and error messages

    @parse_model_resp(resp_type="original_str")
    def refine_entities(prompt, context, model="default"):
        if model == "default":
            return request_model(prompt, context)
        return request_model(prompt, context, model=model)

    triples = []
    chunks = split_text(doc_content, chunk_size, lang=lang)
    for idx, text_chunk in enumerate(chunks):
        print(f"chunk: {idx+1} / {len(chunks)}")
        logger.info(f"{doc_path}, chunk {idx}, \n content: \n{text_chunk}")
        existing_entities = filter_entity_by_text(entity_list, text_chunk, False)
        logger.info(f"existing_entities: {existing_entities}")
        # filter_entity_by_text returns a set, so test emptiness directly;
        # comparing it with {} (a dict) is never true.
        if not existing_entities:
            continue
        extraction_prompt = f"""
            A document is as follows:
            
            {{
            {text_chunk}
            }}
            
            Entities extracted from this document are as follows:
            
            {{
            {existing_entities}
            }}
            
            Based on the original text, what entities do you believe have clear and meaningful relationships? Provide the relationships between them with the following requirements:
            
            1.Use list(tuple) format, where each tuple is represented as a triplet (Subject-Relationship-Object). All entities and relationships should be in singular form with the first letter capitalized.
            2.At least one of the subject or object in the tuple must come from the entities provided above or from the entities given above.
            3.The types of relationships should be diversed, and the number of relationship tuples should be maximized.
            4.The triplets should represent specific, clear, and meaningful relationships and should not include vague relationships like 'is related to'.
            5.The relationship content (predicate) should be short and clear, expressing the semantic relationship in as few words as possible.
                    """

        context = []

        origin_resp = request_model(extraction_prompt, context)
        if origin_resp is None:
            # No point in asking the model to refine a missing answer.
            continue

        refine_prompt_1 = f"""   
            First, remove ambiguous relations like “has”, "is", "related to"
            Then, annotate the types of all entities in all tuples, changing the format to ("Subject Entity Type: Subject Entity", "Relationship Content", "Object Entity Type: Object Entity") based on the original tuples. Each element in the tuple should be marked with double quotes before and after.
            
            Return a list object where each element is a modified tuple, with the entity types before the entities in both the subject and the object. All entities and relationships in the tuples should be in singular form, with the first letter capitalized.
            
            All entity types are as follows; please choose the most appropriate one from the list below and annotate it within the existing entities:
            {{
            {entity_tags}
            }}
        """

        refined_resp1 = request_model(refine_prompt_1, context)
        if refined_resp1 is None:
            continue

        # Ask the model to normalize the output format.
        format_refine_prompt = f"""
            For the list object just mentioned, where each element is a tuple representing a relationship, the following conditions need to be checked and corrected, and then return the corrected result:
            
            1.For entities missing type annotations, label their types as required above, while keeping the subject-verb-object order unchanged.
            2.Translate the Chinese in the tuples into English and correct any misspellings in the English words. All entities and relationships in the tuples should be in singular form, and their initial letters should be capitalized.
            3.The tuple format has three elements。Both subject and object entities are prefixed with their type and separated by a colon : ("Subject Entity Type: Subject Entity", "Relationship Content", "Object Entity Type: Object Entity").
            4.Each element in the tuple should be marked with double quotes " at both ends, and there should be no other quotes. There should be no quotes within the elements or between the tuples; if there are any quotes within an element, remove them.
            5.Return a list object marked with ### (three # symbols) at both ends, ensuring that the list has exactly one level of square brackets [] and no additional content. The format should be as follows: ###[("Subject Entity Type: Subject Entity", "Relationship Content", "Object Entity Type: Object Entity"),]###. Note that the tuples are the content you need to fill in.
            6.Ensure that the entities are consistent with the original text. If the original text is Chinese, the entities are expressed in Chinese. If the original text is in English, the entity is in English.
                    """

        try:
            refined_resp2 = refine_entities(format_refine_prompt, context)
        except Exception:
            # Best effort: skip this chunk. Catching Exception (not a bare except)
            # keeps KeyboardInterrupt able to stop the notebook.
            continue

        max_try = 3
        while max_try > 0:
            try:
                # Map full-width punctuation to ASCII so literal_eval can parse it.
                refined_resp2 = refined_resp2.replace("）",")").replace("（", "(").replace("：", ":").replace("，", ",")
                ret = ast.literal_eval(refined_resp2)
                if not isinstance(ret, list):
                    raise ValueError(f"expected a list, got {type(ret).__name__}")
                triples.extend(ret)
                print(ret)
                break
            except Exception as err:
                max_try -= 1
                err_msg = f"err: {err} at {doc_path} chunk {idx}"
                logger.error(err_msg)
                if max_try == 0:
                    # Out of attempts: do not spend another model call whose reply would be discarded.
                    print("exceed max retry, continue")
                    break
                format_error_prompt = f"""
                    From the returned content above, parsed out the following object: {refined_resp2}
                    An error occurred while using ast.literal_eval to parse the returned object: {err_msg}
                    
                    Continue to modify the format of the object so that it meets the requirements of the above format and can be correctly parsed by ast.literal_eval.
                    Note that you only need to provide the corrected result and mark it with three # symbols formatted like ###reutrned object###.
                    
                    Typical errors that may occur:
                    1.Pay attention to whether each element in the tuple has double quotes; if so, they must be removed.
                    2.Instead of using colons to separate entities and types, bracketed tags are incorrectly used
                    3.Each tuple can only have 3 elements. The entity type and entity must be in the same element of the tuple (i.e., within the same string) and separated by a colon. It must conform to the format ("subject entity type:subject entity", "relationship content", "object entity type:object entity"). If the type and entity are found in different strings, they need to be combined into one.   
                    4.The returned object may not be marked with ###
                    5.Entities may miss type annotations, you need to label their types as required above, while keeping the subject-verb-object order unchanged."""

                try:
                    refined_resp2 = refine_entities(format_error_prompt, context)
                except Exception:
                    print("resp parsing error")
                    continue
                print(f"error: {err_msg}, retry")

    logger.info(f"triples:\n {triples}")

    return triples

# all_entities: {doc_path: {entity_tag: [entity, ...]}, ...}
def extract_KB_triples(entity_list_path, entity_schema_path, chunk_size=200):
    """Run ``extract_triple`` for every document that has extracted entities.

    Args:
        entity_list_path: file mapping doc_path -> {tag: [entity, ...]}.
        entity_schema_path: file holding the entity schema.
        chunk_size: token budget per text chunk, forwarded to ``extract_triple``.

    Returns:
        dict mapping doc_path -> list of triples. Documents whose entity dict is
        empty are reported in the progress output but left out of the result.
    """
    all_entities = load_dict_from_file(entity_list_path)
    entity_schema = load_dict_from_file(entity_schema_path)
    total = len(all_entities)

    all_triples = {}
    for position, (doc_path, entity_list) in enumerate(all_entities.items(), start=1):
        print(f"{position}/{total}: {doc_path}")
        if entity_list == {}:
            continue
        all_triples[doc_path] = extract_triple(entity_schema, entity_list, doc_path, chunk_size=chunk_size)
    return all_triples

# Keys of one normalized triple record produced by post_process.
_TRIPLE_RECORD_KEYS = ("subject_tag", "subject_entity", "predicate", "object_tag", "object_entity")


def _triple_to_record(tup, entity_tags):
    """Turn one raw ``("SType: S", "P", "OType: O")`` triple into a record dict.

    Returns None when ``tup`` is not a 3-item sequence of strings whose subject
    and object both contain a ``:`` type separator.
    """
    if not isinstance(tup, (list, tuple)) or len(tup) != 3:
        return None
    s, p, o = tup
    if not (isinstance(s, str) and isinstance(p, str) and isinstance(o, str)):
        return None
    if ":" not in s or ":" not in o:
        return None
    # Split on the first colon only, so entities that contain ':' stay intact.
    s_tag, s_entity = s.split(':', 1)
    o_tag, o_entity = o.split(':', 1)
    return {
        # Tags are stripped so padding like "Type :" does not disturb the match.
        'subject_tag': find_best_match(s_tag.strip(), entity_tags),
        'subject_entity': s_entity.strip(),
        'predicate': p.strip(),
        'object_tag': find_best_match(o_tag.strip(), entity_tags),
        'object_entity': o_entity.strip(),
    }


def post_process(triples_path, entity_schema_path):
    """Normalize raw LLM triples into tagged records and save them back to ``triples_path``.

    Each raw triple ``("SubjectType: Subject", "predicate", "ObjectType: Object")``
    is split on the first colon of subject and object. Both types are snapped to
    a schema tag with ``find_best_match``. The triple then becomes a dict with
    the keys ``subject_tag``, ``subject_entity``, ``predicate``, ``object_tag``
    and ``object_entity``.

    Malformed triples are dropped one by one, so a single bad triple does not
    discard the rest of its document. Entries that are already normalized
    records are kept unchanged, so running this again on its own (overwritten)
    output does not wipe the data.

    Args:
        triples_path: file mapping doc_path -> list of triples; overwritten with the result.
        entity_schema_path: file holding the entity schema (list of {tag: ...} dicts).

    Returns:
        dict mapping doc_path -> list of record dicts.
    """
    triples = load_dict_from_file(triples_path)
    entity_schema = load_dict_from_file(entity_schema_path)
    entity_tags = list(set(key for item in entity_schema for key in item.keys()))

    output_json = {}
    for doc_path, doc_triples in triples.items():
        if not isinstance(doc_triples, (list, tuple)):
            # Not a list of triples; leave this document out of the output.
            continue
        output_list = []
        for tup in doc_triples:
            if isinstance(tup, dict) and all(k in tup for k in _TRIPLE_RECORD_KEYS):
                # Already normalized by an earlier run; keep it as-is.
                output_list.append(tup)
                continue
            try:
                record = _triple_to_record(tup, entity_tags)
            except Exception as err:
                logger.error(f"post_process: skipping triple {tup!r} in {doc_path}: {err}")
                continue
            if record is not None:
                output_list.append(record)
        output_json[doc_path] = output_list
    save_dict_to_file(output_json, triples_path)
    return output_json

In [None]:
# Stage 4: extract relation triples per document, then normalize them in place.
CHUNK_SIZE = 200  # token budget per text chunk sent to the model (forwarded to split_text)
all_triples_path = f"{OUTPUT_PATH}/all_triples.json"
all_triples = extract_KB_triples(f"{OUTPUT_PATH}/all_entities.json", f"{OUTPUT_PATH}/full_entity_schema.json", chunk_size=CHUNK_SIZE)
save_dict_to_file(all_triples, all_triples_path)
# post_process overwrites all_triples.json with normalized records. Show only the
# number of documents instead of rendering the whole result dict as cell output.
processed_triples = post_process(all_triples_path, f"{OUTPUT_PATH}/full_entity_schema.json")
len(processed_triples)