In [1]:
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import partial
import json
from json.decoder import JSONDecodeError
import os

from dotenv import load_dotenv
from icecream import ic
from neo4j import GraphDatabase
import pandas as pd
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModel

from data_processing_helpers import flatten_dict, load_schema, process_raw_data, load_data
from llm_helpers import create_prompt, process_template, sentences2embeddings
from neo4j_helpers import NodeModel, RelationshipModel, create_nodes, create_relationships


load_dotenv()

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
OPENAI_MODEL = os.getenv("OPENAI_MODEL")
OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL")


# Load Label Data

In [2]:
schema_list = load_schema("data/53_schemas.json")
train_raw_list = load_data("data/train_data.json")
val_raw_list = load_data("data/val_data.json")

train_list, val_list = process_raw_data(train_raw_list), process_raw_data(val_raw_list)
val_sample_list = val_list[:10]
# train_data_path = 'data/train_data.json'

# with open(train_data_path, 'r') as file:
#     train_data = [json.loads(line) for line in file.readlines()]

# with open('data/val_data.json', 'r') as file:
#     val_data = [json.loads(line) for line in file.readlines()]

# train_data_list = [process_raw_data(item) for item in train_data]

# train_list = [flatten_dict(item) for item in train_data]
# train_sample_list = train_list[:10]

# val_list = [flatten_dict(item) for item in val_data]
# val_sample_list = val_list[:10]

# schema_list = load_schema("data/53_schemas.json")

In [3]:
train_list[0]

{'spo': {'subject': '产后抑郁症',
  'subject_type': '疾病',
  'object': '轻度情绪失调',
  'object_type': '疾病',
  'predicate': '鉴别诊断'},
 'text': '产后抑郁症@区分产后抑郁症与轻度情绪失调（产后忧郁或“婴儿忧郁”）是重要的，因为轻度情绪失调不需要治疗。'}

In [4]:
schema_list[:5]

[{'subject_type': '疾病', 'predicate': '预防', 'object_type': '其他'},
 {'subject_type': '疾病', 'predicate': '阶段', 'object_type': '其他'},
 {'subject_type': '疾病', 'predicate': '就诊科室', 'object_type': '其他'},
 {'subject_type': '其他', 'predicate': '同义词（其他/其他）', 'object_type': '其他'},
 {'subject_type': '疾病', 'predicate': '辅助治疗', 'object_type': '其他治疗'}]

In [5]:
val_sample_list

[{'spo': {'subject': '急性胰腺炎',
   'subject_type': '疾病',
   'object': 'ERCP',
   'object_type': '检查',
   'predicate': '影像学检查'},
  'text': '急性胰腺炎@有研究显示，进行早期 ERCP （24 小时内）可以降低梗阻性胆总管结石患者的并发症发生率和死亡率； 但是，对于无胆总管梗阻的胆汁性急性胰腺炎患者，不需要进行早期 ERCP。'},
 {'spo': {'subject': '广泛性焦虑症',
   'subject_type': '疾病',
   'object': '社交性焦虑',
   'object_type': '疾病',
   'predicate': '鉴别诊断'},
  'text': '【诊断】 根据疾病诊断和统计手册第4版标准，焦虑情绪持续6个月以上，并至少下述4项症状： 1.担忧将来的意外事件； 2.担忧自己的能力； 3.担忧过去的行为； 4.躯体不适症状； 5.自我意识（对主体的自我认识）； 6.不断需要得到他人的确认； 7.持续紧张和（或）不能放松； 广泛性焦虑症影响社会交往，与分离性焦虑症比较，更多伴有其他焦虑症，如惊恐发作或单纯性恐怖症。 （二）社交性焦虑 尽管两种疾病均害怕在公众场合下说话，但广泛性焦虑也害怕对过去和将来情形的焦虑。'},
 {'spo': {'subject': '骨性关节炎',
   'subject_type': '疾病',
   'object': '关节',
   'object_type': '部位',
   'predicate': '发病部位'},
  'text': '骨性关节炎@在其他关节（如踝关节和腕关节），骨性关节炎比较少见，并且一般有潜在的病因（如结晶性关节病、创伤）。'},
 {'spo': {'subject': '胆囊穿孔',
   'subject_type': '疾病',
   'object': '30%',
   'object_type': '流行病学',
   'predicate': '死亡率'},
  'text': '胆囊炎@如果胆囊穿孔，死亡率为 30%。'},
 {'spo': {'subject': '乙型肝炎',
   'subje

# Create SPO from Unstructured Text

In [6]:
prompts = create_prompt(train_list, val_sample_list, schema_list, sample_size=10)
process_template_partial = partial(
    process_template, 
    model=OPENAI_MODEL, 
    api_key=OPENAI_API_KEY,
    base_url=OPENAI_BASE_URL
)

result_list = []
error_case_list = []

num_workers = 10

with ThreadPoolExecutor(max_workers=num_workers) as executor:
    future_to_template = {executor.submit(process_template_partial, prompt): prompt for prompt in prompts}
    
    for future in tqdm(as_completed(future_to_template), total=len(future_to_template)):
        template = future_to_template[future]
        try:
            result = future.result()
            if result[0] == 'error':
                error_case_list.append(result[1])
            else:
                result_list.append(result[1])
        except Exception as e:
            error_case_list.append(('Exception', str(e)))

  0%|          | 0/10 [00:00<?, ?it/s]

In [7]:
len(result_list)

5

In [8]:
result_list[0]

{'subject': '胆囊炎',
 'subject_type': '疾病',
 'object': '死亡率',
 'object_type': '流行病学',
 'predicate': '死亡率',
 'text': '胆囊炎@如果胆囊穿孔，死亡率为 30%。'}

# SPO to Knowledge Graph

In [9]:
driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))

result_df = pd.DataFrame(result_list)

pattern = r'[^\u4e00-\u9fffA-Za-z\s]'

for col in ['subject', 'object', 'predicate']:
    # Apply the function to the column
    result_df[col] = result_df[col].str.replace(pattern, '', regex=True)
    result_df = result_df.drop(result_df[result_df[col] == ''].index)
    result_df[col] = result_df[col].str.replace(' ', '_')

In [10]:
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-large-en")
model = AutoModel.from_pretrained("BAAI/bge-large-en")
model.eval()


embedding_func = partial(sentences2embeddings, model=model, tokenizer=tokenizer)  

node_knowledge_train = NodeModel(
    label=("hard_coded_label", "Knowledge"),
    id_prop=('text', 'text'),
    properties={
        "text": ("text_embedding", embedding_func)
    }
)

node_subject_train = NodeModel(
    # label=("column_name_label", "subject"),
    label=("hard_coded_label", "Subject"),
    id_prop=('subject', 'name'),
    properties={
        "subject_type": ("type", None)
    },
    # extra_labels=['Subject', 'Result']
)

node_object_train = NodeModel(
    # label=("column_name_label", "object"),
    label=("hard_coded_label", "Object"),
    id_prop=('object', 'name'),
    properties={
        "object_type": ("type", None)
    },
)

rel_train_knowledge2subject = RelationshipModel(
    source_node='Knowledge',
    target_node='Subject',
    # rel_label=('column_name_label', 'predicate'),
    rel_label=('hard_coded_label', 'HAS_SUBJECT'),
    source_id=('text', 'text'),
    target_id=('subject', 'name'),
    # extra_labels=['Predicate', 'Result']
)

rel_train_subject2object = RelationshipModel(
    source_node='Subject',
    target_node='Object',
    # rel_label=('column_name_label', 'predicate'),
    rel_label=('hard_coded_label', 'PREDICATE'),
    source_id=('subject', 'name'),
    target_id=('object', 'name'),
    properties={
        "predicate": "type"
    }
)


In [11]:
with driver.session() as session:
    session.write_transaction(create_nodes, result_df, [node_knowledge_train, node_subject_train, node_object_train])
    session.write_transaction(create_relationships, result_df, [rel_train_knowledge2subject, rel_train_subject2object])

  session.write_transaction(create_nodes, result_df, [node_knowledge_train, node_subject_train, node_object_train])


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  session.write_transaction(create_relationships, result_df, [rel_train_knowledge2subject, rel_train_subject2object])


In [13]:
from agent import generate_response


message = '胆囊炎死亡率有多高？'

response = generate_response(message)

Parent run b335a461-068b-4af1-8133-125e744e06da not found for run 749fddd8-5463-4f5b-94fc-7f8d22c3457e. Treating as a root run.




[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mThought: Do I need to use a tool? Yes
Action: Medical Knowledge Search
Action: 胆囊炎死亡率
Observation[0mInvalid Format: Missing 'Action Input:' after 'Action:'[32;1m[1;3mThought: Do I need to use a tool? Yes
Action: Medical Knowledge Search
Action Input: 胆囊炎死亡率
Observation[0m[33;1m[1;3m{'input': '胆囊炎死亡率\nObservation', 'context': [Document(page_content='乙型肝炎@## 患者指导 HBsAg阳性的人性交时如果对方未接种疫苗或无自然免疫应采取防护，不应与他人共用牙刷或剃须刀，应遮盖开放的割伤和擦伤，使用漂白剂或洗涤剂清洁溢出血液，不献血、不捐献器官或精液。'), Document(page_content='胆囊炎@如果胆囊穿孔，死亡率为 30%。'), Document(page_content='稳定型缺血性心脏疾病@ * 所有年龄超过 21 岁、LDL 大于或等于 190 的患者都应当接受高强度他汀类药物治疗。'), Document(page_content='骨性关节炎@在其他关节（如踝关节和腕关节），骨性关节炎比较少见，并且一般有潜在的病因（如结晶性关节病、创伤）。')], 'answer': '根据给定的上下文，如果胆囊炎导致胆囊穿孔，死亡率为 30%。'}[0m[32;1m[1;3mThought: Do I need to use a tool? No
Final answer: 如果胆囊炎导致胆囊穿孔，死亡率为30%。[0mInvalid Format: Missing 'Action:' after 'Thought:[32;1m[1;3mThought: Do I need to use a tool? No
Final answer: 如果胆囊炎导致胆囊穿孔，死

In [None]:
print(response)