In [None]:
import ast
import os

import pandas as pd
from tqdm.notebook import tqdm_notebook
from zlai.llms import Zhipu
from zlai.llms.generate_config.api import GLM4FlashGenerateConfig, GLM4GenerateConfig, GLM4AirGenerateConfig
from zlai.agent import GraphAgent, EntityAgent
from zlai.parse import ParseCode

In [None]:
# Choose the model configuration. GLM4GenerateConfig is the active choice;
# GLM4FlashGenerateConfig and GLM4AirGenerateConfig (also imported above) can be
# swapped in via MODEL_CONFIG. All variants use max_tokens=8192.
MODEL_CONFIG = GLM4GenerateConfig
llm = Zhipu(generate_config=MODEL_CONFIG(max_tokens=8192))

In [None]:
# Load and clean the text: strip full-width spaces (U+3000), split into chunks on
# runs of six newlines (presumably chapter boundaries in this source file), and
# keep only lines longer than MIN_PARAGRAPH_LEN characters; shorter lines are dropped.
CHAPTER_SEPARATOR = "\n\n\n\n\n\n"
MIN_PARAGRAPH_LEN = 128

with open("../data/西游记.md", "r", encoding="utf-8") as f:
    content = f.read().replace("\u3000", "")

data = content.split(CHAPTER_SEPARATOR)

book = [
    line
    for chapter_text in data
    for line in chapter_text.split("\n")
    if len(line) > MIN_PARAGRAPH_LEN
]

def batches(lst: list, batch_size: int):
    """Yield consecutive slices of ``lst`` holding at most ``batch_size`` items each.

    The last slice holds the remainder and may be shorter than ``batch_size``.

    :param lst: The source list (any sliceable sequence works).
    :param batch_size: Maximum number of items per batch; must be positive.
    :return: A generator over slices of ``lst``, in their original order.
    :raises ValueError: If ``batch_size`` is zero or negative (raised on first iteration).
    """
    if batch_size <= 0:
        # range() rejects a zero step, but a negative step would silently produce an
        # empty range here and drop every item; fail loudly in both cases.
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    for start in range(0, len(lst), batch_size):
        yield lst[start:start + batch_size]

In [None]:
# Define the output CSV paths and their column schemas.
entity_file = "./entity_v1.csv"
relation_file = "./relation_v1.csv"

entity_columns = ["entity_name", "entity_type", "entity_description"]
relation_columns = ["source", "target", "relationship", "description", "strength"]

df_entity = pd.DataFrame(columns=entity_columns)
df_relation = pd.DataFrame(columns=relation_columns)

# Write a header-only CSV for each file that does not exist yet, so later cells can
# append rows with header=False. Existing files are left untouched.
for path, frame in ((entity_file, df_entity), (relation_file, df_relation)):
    if not os.path.exists(path):
        frame.to_csv(path, index=False)

In [None]:
# Join every PARAGRAPHS_PER_BATCH consecutive paragraphs (newline-separated) into
# one text chunk; each chunk is sent to the extraction agent as one request.
PARAGRAPHS_PER_BATCH = 5
book_data = list(map("\n".join, batches(book, PARAGRAPHS_PER_BATCH)))
total = len(book_data)
error_chapter = []  # indices into book_data whose extraction failed
total

In [None]:
# Batch extraction over every chunk in book_data. Only person ("人物") and place
# ("地点") entities are requested. Parsed entities and relations are appended to the
# CSV files. The index of any chunk that raises is recorded in error_chapter so it
# can be retried individually.
ENTITY_TYPES = str(["人物", "地点"])  # loop-invariant; built once

for i, chapter in tqdm_notebook(enumerate(book_data), total=total):
    try:
        # A fresh agent is created for every chunk. NOTE(review): presumably this
        # avoids carrying state between chunks; confirm against GraphAgent semantics.
        agent = GraphAgent(llm=llm, verbose=False)
        task_completion = agent(chapter, entity_types=ENTITY_TYPES)
        # The model output is untrusted text. ast.literal_eval accepts only Python
        # literals (lists, tuples, strings, numbers), whereas eval() would execute
        # any code the model happened to emit.
        entities = ast.literal_eval(task_completion.data.get("entities"))
        relations = ast.literal_eval(ParseCode.sparse_script(task_completion.content)[0])
        df_entity = pd.DataFrame(entities, columns=entity_columns)
        df_relation = pd.DataFrame(relations, columns=relation_columns)
        df_entity.to_csv(entity_file, mode='a', header=False, index=False)
        df_relation.to_csv(relation_file, mode='a', header=False, index=False)
    except Exception as e:
        # Best-effort: one malformed response must not abort the whole run.
        error_chapter.append(i)
        print(f"{i}: {e}")

In [None]:
# Indices into book_data whose extraction raised an exception in the batch loop;
# these are candidates for the single-chunk retry that follows.
error_chapter

In [None]:
# Re-run extraction for one chunk that failed in the batch loop.
# NOTE(review): 127 is hard-coded, presumably copied from the error_chapter output.
# Set retry_index to the failing index before re-running this cell.
retry_index = 127
chapter = book_data[retry_index]
agent = GraphAgent(llm=llm, verbose=True)  # verbose, presumably to inspect the agent's output
task_completion = agent(chapter, entity_types=str(["人物", "地点"]))

In [None]:
# Parse the retried chunk's output into entity/relation frames. The output is
# untrusted model text, so use ast.literal_eval (Python literals only) instead of
# eval(), which would execute arbitrary code.
df_entity = pd.DataFrame(
    ast.literal_eval(task_completion.data.get("entities")),
    columns=entity_columns,
)
df_relation = pd.DataFrame(
    ast.literal_eval(ParseCode.sparse_script(task_completion.content)[0]),
    columns=relation_columns,
)

In [None]:
# Append the retried chunk's rows to the output CSVs. header=False because the
# header rows were written when the files were first created.
for frame, path in ((df_entity, entity_file), (df_relation, relation_file)):
    frame.to_csv(path, mode="a", header=False, index=False)

-------