In [1]:
import json
import os
import re
import time

import pandas as pd
import requests
from zhipuai import ZhipuAI

# Initialize ZhipuAI client.
# The API key is read from the ZHIPUAI_API_KEY environment variable when it is set, so the
# secret does not have to be pasted into this file. The "_____" placeholder is kept as the
# fallback value; replace it only if you cannot use the environment variable.
client = ZhipuAI(api_key=os.environ.get("ZHIPUAI_API_KEY", "_____"))

# Folder paths and files
disease_folder = "disease introduction/filtered"  # Folder containing filtered files
output_path = "KG/"  # Output folder for processed files

# Ensure output directory exists (no error if it is already there)
os.makedirs(output_path, exist_ok=True)

# Function to generate the OIE prompt
def get_oie_prompt(disease_name, disease_intro):
    """Build the open information extraction (OIE) prompt for one disease.

    Args:
        disease_name: Name of the disease, quoted in the first sentence of the prompt.
        disease_intro: Introduction text the triplets should be extracted from.

    Returns:
        The prompt string. It starts with a newline, every non-empty line is indented by
        four spaces, and it ends with a line holding only four spaces.
    """
    # The template is spelled out line by line so that its whitespace-only lines
    # (which are part of the prompt text) are explicit.
    prompt_lines = (
        "",
        f"    Given the disease name as '{disease_name}', extract the relevant entities and relationships from the following disease introduction:",
        f"    {disease_intro}.",
        "    ",
        "    Instructions:",
        "    1. The **symptoms**, **emergency examination**, **precautions**, **prognosis**, and other such terms should be treated as **relations**, not as subjects or objects.",
        "    2. Extract **entities** as specific **noun phrases** or **proper nouns**, avoiding full sentences or long descriptions. For example, for the symptom 'fever and headache', use 'fever' and 'headache' as entities, rather than the full sentence.",
        "    3. Do **not combine** multiple pieces of information into a single triplet. If a subject and object contain multiple related pieces of information, they should be split into separate triplets. For example, if a symptom includes multiple manifestations (e.g., 'fatigue, chest pain'), each manifestation should be captured as an individual triplet.",
        "    4. Ensure that the knowledge graph is **connected**, meaning that the subject and object are linked through meaningful relationships, and that no entities or relations are isolated or fragmented.",
        "    5. Avoid confusing relations and entities. Relations should represent connections between entities (e.g., 'causes', 'leads to', 'is treated with'), whereas entities should represent specific items such as disease names, symptoms, treatments, etc.",
        "    6. When extracting relationships, ensure clarity and avoid ambiguity. If the sentence mentions multiple relationships or entities, separate them clearly into distinct triplets.",
        "",
        "    Return the extracted relationships as a list of triplets in the form of [Subject, Relation, Object].",
        "    Example: ['Thyroid cancer', 'isTypeOf', 'Malignant tumor'].",
        "    ",
    )
    return "\n".join(prompt_lines)


# Function to generate the relation definition prompt
def get_relation_definition_prompt(disease_name, KG, relations):
    """Build the prompt that asks the model to define each extracted relation.

    Args:
        disease_name: Name of the disease the relations were extracted for.
        KG: The previously extracted knowledge graph; interpolated into the prompt
            between single quotes using its string form.
        relations: The relations to define; interpolated using their string form.

    Returns:
        The prompt string. It starts with a newline, every non-empty line is indented by
        four spaces, and it ends with a line holding only four spaces.
    """
    # Whitespace-only lines are part of the prompt text, so they are written explicitly.
    blank = "    "
    prompt_lines = [
        "",
        f"    Given the extracted relationships from the disease '{disease_name}', define each relation in clear and natural language.",
        f"    The original knowledge graph (KG) is shown below: '{KG}'.",
        blank,
        "    Instructions:",
        "    1. Your definitions should be **general** and **applicable** across different diseases, ensuring that they can be used in multiple disease knowledge graphs.",
        "    2. Do not include any **disease-specific information** (e.g., symptoms, treatments, or specific diseases) in the relationship definitions.",
        "    3. The definitions should be **simple, precise, and concise**, focusing on the relationship itself without detailing specific examples.",
        blank,
        "    Example:",
        "    - 'isTypeOf': 'This relation indicates that the subject is a category or classification of the object entity, typically categorizing an entity in terms of a broader class.'",
        blank,
        "    Now, please provide definitions for the following relations:",
        f"    {relations}",
        blank,
    ]
    return "\n".join(prompt_lines)


# Function to generate the canonicalization prompt
def get_canonicalization_prompt(disease_name, relations, definitions):
    """Build the prompt that asks the model to standardize a disease's relations.

    Args:
        disease_name: Name of the disease, quoted in the first sentence.
        relations: The relations to standardize; interpolated using their string form.
        definitions: The relation definitions to rely on; interpolated the same way.

    Returns:
        The prompt string. It starts with a newline, each line is indented by four
        spaces, and it ends with a line holding only four spaces.
    """
    indent = "    "
    heading = (
        f"Given the relationships for disease '{disease_name}', "
        "standardize the following relations to ensure consistency across diseases:"
    )
    guidance = "Use the definitions provided below to ensure the relationships are standardized:"
    body_lines = [heading, f"{relations}", guidance, f"{definitions}"]
    # Leading newline, one indented line per entry, then a final indent-only line.
    return "\n" + "".join(f"{indent}{text}\n" for text in body_lines) + indent

# Helper function to upload JSONL to API and start batch job
def upload_jsonl_and_create_batch(jsonl_path):
    """Upload a JSONL request file and create a batch job that consumes it.

    Args:
        jsonl_path: Path of the JSONL file to upload; it is opened in binary mode.

    Returns:
        The ID of the newly created batch job.
    """
    with open(jsonl_path, "rb") as jsonl_file:
        uploaded = client.files.create(file=jsonl_file, purpose="batch")
        print(f"File uploaded successfully. File ID: {uploaded.id}")

        # The job targets the chat-completions endpoint and references the uploaded file.
        job_options = {
            "input_file_id": uploaded.id,
            "endpoint": "/v4/chat/completions",
            "auto_delete_input_file": True,
            "metadata": {"description": "Disease Knowledge Graph Construction"},
        }
        batch_job = client.batches.create(**job_options)
        print(f"Batch job created successfully. Batch ID: {batch_job.id}")
    return batch_job.id

# Poll for batch processing status
def check_batch_status(batch_id, poll_interval=30, timeout=None):
    """Poll a batch job until it reaches a final state or the time budget runs out.

    Args:
        batch_id: ID of the batch job to poll.
        poll_interval: Seconds to sleep between two status requests (default 30).
        timeout: Maximum number of seconds to keep polling. ``None`` (the default)
            keeps polling until the job reports a final status.

    Returns:
        True when the job status is "completed". False when the status is "failed",
        "expired" or "cancelled", or when ``timeout`` elapsed first.
    """
    # monotonic() is unaffected by wall-clock adjustments during a long wait.
    deadline = None if timeout is None else time.monotonic() + timeout

    while True:
        status = client.batches.retrieve(batch_id).status
        print(f"Current status: {status}")

        if status == "completed":
            print("Batch processing completed. Downloading results.")
            return True
        if status in ("failed", "expired", "cancelled"):
            print("Batch processing failed or cancelled.")
            return False

        # Any other status means the job is still underway: keep waiting unless
        # the caller-supplied time budget is used up.
        if deadline is not None and time.monotonic() >= deadline:
            print(f"Stopped waiting for batch {batch_id} after {timeout} seconds.")
            return False
        time.sleep(poll_interval)

# Download the batch processing results
def download_results(batch_id, stage):
    """Save the output file of a batch job into the output folder.

    Args:
        batch_id: ID of the batch job whose output should be fetched.
        stage: Stage identifier used in the saved file name
            (``batch_output_stage_<stage>.jsonl``).

    Returns:
        Path of the file the results were written to.
    """
    output_file_id = client.batches.retrieve(batch_id).output_file_id
    result_file = os.path.join(output_path, f"batch_output_stage_{stage}.jsonl")

    # Fetch the output file's content and write it straight to disk.
    client.files.content(output_file_id).write_to_file(result_file)
    print(f"Results downloaded and saved to {result_file}")
    return result_file

# Process the results from the API and save
def process_results(result_file, stage):
    """Add the model's reply text as a ``KG_<stage>`` column and save the table.

    Args:
        result_file: Path of a JSON-lines file whose records carry a "response" field.
        stage: Stage identifier; used for the new column name and the output file name
            (``processed_results_<stage>.json`` inside the output folder).

    Returns:
        The DataFrame read from ``result_file`` with the extra ``KG_<stage>`` column.
    """
    def _reply_text(response):
        # The reply is the message content of the first choice, without outer whitespace.
        first_choice = response['body']['choices'][0]
        return first_choice['message']['content'].strip()

    results = pd.read_json(result_file, lines=True)
    kg_column = f'KG_{stage}'
    results[kg_column] = results["response"].apply(_reply_text)

    destination = os.path.join(output_path, f'processed_results_{stage}.json')
    results.to_json(destination, orient="records", lines=True, force_ascii=False)
    return results

# Helper function to read all disease files and gather text data
def load_disease_data_from_files(folder=None):
    """Read every ``.txt`` file of a folder into a ``{disease name: text}`` mapping.

    Args:
        folder: Directory to scan. Defaults to the module-level ``disease_folder``.

    Returns:
        A dict mapping each file name without its ``.txt`` extension to the file's
        UTF-8 decoded content. Entries are inserted in sorted file-name order so
        that repeated runs see the diseases in the same order.
    """
    if folder is None:
        folder = disease_folder

    disease_data = {}
    # os.listdir returns entries in arbitrary order; sort for reproducible output.
    for filename in sorted(os.listdir(folder)):
        # splitext removes only the final extension, so a name such as
        # "a.txt.v2.txt" keeps its inner ".txt" instead of losing every occurrence.
        disease_name, extension = os.path.splitext(filename)
        if extension != ".txt":  # Process only txt files
            continue
        file_path = os.path.join(folder, filename)
        with open(file_path, 'r', encoding='utf-8') as file:
            disease_data[disease_name] = file.read()
    return disease_data


# Boundary between two quoted triplet items: closing quote, comma, opening quote.
# Accepts single or double quotes and any amount of whitespace around the comma.
_TRIPLET_ITEM_SEPARATOR = re.compile(r"""['"]\s*,\s*['"]""")


def process_kg_1_relation(kg_1_str):
    """Collect the distinct relation names from a textual list of triplets.

    Each line of ``kg_1_str`` is expected to hold one triplet such as
    ``37. ['Thyroid cancer', 'isTypeOf', 'Malignant tumor']``. Items may be wrapped in
    single or double quotes, with or without a space after the comma. Lines that do not
    split into exactly three items are ignored.

    Args:
        kg_1_str: Triplet text, one triplet per line.

    Returns:
        A set with the middle item (the relation) of every recognised triplet.
    """
    relations = set()
    for raw_line in kg_1_str.strip().split('\n'):
        # Drop everything up to the first "[" (e.g. a numeric prefix like "37. ");
        # lines without a bracket are parsed as they are.
        _, bracket, after_bracket = raw_line.strip().partition("[")
        triplet_text = after_bracket if bracket else raw_line
        parts = _TRIPLET_ITEM_SEPARATOR.split(triplet_text.strip().strip('[]'))

        if len(parts) != 3:
            continue
        # Only the middle item is needed; remove stray quotes and whitespace.
        relation = parts[1].strip().strip("'\"")
        if relation:
            relations.add(relation)
    return relations


# Build the user prompt for a single disease at a given stage
def _build_stage_prompt(stage, disease_name, disease_intro, previous_row):
    """Return the prompt for one disease; ``previous_row`` is None when no earlier results exist."""
    # Stage 1: Information extraction (OIE)
    if stage == 1:
        return get_oie_prompt(disease_name, disease_intro)

    # Stage 2: Relation definition (using Stage 1 results)
    if stage == 2:
        if previous_row is None:
            return get_relation_definition_prompt(disease_name, [], [])
        kg_one = previous_row["KG_1"]
        return get_relation_definition_prompt(disease_name, kg_one, process_kg_1_relation(kg_one))

    # Stage 3: Canonicalization (using Stage 1 and Stage 2 outputs)
    if previous_row is None:
        return get_canonicalization_prompt(disease_name, [], [])
    # NOTE(review): this reads both "KG_1" and "KG_2" from the same row, so the frame
    # passed for Stage 3 must carry both columns - confirm against the caller.
    return get_canonicalization_prompt(disease_name, previous_row["KG_1"], previous_row["KG_2"])


# Helper function to process each stage
def process_stage(disease_data, stage, previous_stage_results=None, force_run=False):
    """Run one pipeline stage as a batch job and return its processed results.

    For every disease a chat-completion request is built, all requests are written to
    ``batch_requests_KG_stage_<stage>.jsonl`` in the output folder, the file is uploaded
    as a batch job, and once the job completes its output is downloaded and processed.

    Args:
        disease_data: Mapping of disease name to introduction text.
        stage: 1 (information extraction), 2 (relation definition) or
            3 (canonicalization).
        previous_stage_results: DataFrame with the earlier stage's output, either
            indexed by ``custom_id`` or carrying a ``custom_id`` column. When given, it
            is always used to build Stage 2/3 prompts.
        force_run: Kept for backward compatibility. It no longer makes the stage
            ignore ``previous_stage_results``, because prompts built without the
            earlier output contain only empty lists.

    Returns:
        The DataFrame produced by ``process_results``, or None when the batch job did
        not complete.

    Raises:
        ValueError: If ``stage`` is not 1, 2 or 3.
        KeyError: If a disease has no row in ``previous_stage_results``.
    """
    if stage not in (1, 2, 3):
        raise ValueError(f"Unknown stage: {stage!r} (expected 1, 2 or 3)")

    # Results coming straight from process_results have custom_id as a plain column;
    # index by it so the per-disease lookup below works for both layouts.
    if previous_stage_results is not None and "custom_id" in previous_stage_results.columns:
        previous_stage_results = previous_stage_results.set_index("custom_id")

    all_requests = []
    for disease_name, disease_intro in disease_data.items():
        # Pad short names with "_" to six characters; longer names are left unchanged.
        custom_id = disease_name.ljust(6, "_")

        previous_row = None
        if stage > 1 and previous_stage_results is not None:
            previous_row = previous_stage_results.loc[custom_id]
        prompt = _build_stage_prompt(stage, disease_name, disease_intro, previous_row)

        request = {
            "custom_id": custom_id,
            "method": "POST",
            "url": "/v4/chat/completions",
            "body": {
                "model": "glm-4-flash",  # Adjust this model if necessary
                "messages": [
                    {"role": "system", "content": f"Process disease information for '{disease_name}'."},
                    {"role": "user", "content": prompt}
                ]
            }
        }
        all_requests.append(json.dumps(request, ensure_ascii=False))

    # Path to save the JSONL file (same folder as every other pipeline artefact)
    jsonl_path = os.path.join(output_path, f"batch_requests_KG_stage_{stage}.jsonl")
    with open(jsonl_path, "w", encoding="utf-8") as f:
        f.write("\n".join(all_requests))

    batch_id = upload_jsonl_and_create_batch(jsonl_path)
    if check_batch_status(batch_id):
        result_file = download_results(batch_id, stage)
        return process_results(result_file, stage)

    return None


# Main entry point for running the entire process
def main(force_run_all=False):
    """Run Stage 1 (information extraction) and Stage 2 (relation definition).

    A stage whose ``processed_results_<n>.json`` file already exists in the output
    folder is loaded from disk instead of being re-run, unless ``force_run_all`` is set.

    Args:
        force_run_all: When True, ignore saved results and run every stage again.

    Returns:
        None. Stops before Stage 2 when Stage 1 produced no results.
    """
    # Load disease data from filtered folder
    disease_data = load_disease_data_from_files()

    # Stage 1: Information extraction (OIE)
    stage_1_file = os.path.join(output_path, "processed_results_1.json")
    if os.path.exists(stage_1_file) and not force_run_all:
        print("Stage 1 results already exist. Loading them...")
        stage_1_results = pd.read_json(stage_1_file, lines=True)
    else:
        print("Stage 1 results do not exist or forced to run from scratch. Running Stage 1...")
        stage_1_results = process_stage(disease_data, 1, force_run=force_run_all)

    # process_stage returns None when the batch job did not complete; Stage 2 has
    # nothing to work with in that case.
    if stage_1_results is None:
        print("Stage 1 did not produce any results. Stopping before Stage 2.")
        return

    # Stage 2 looks rows up by custom_id, so index the frame the same way whether it
    # was loaded from disk or freshly computed.
    if "custom_id" in stage_1_results.columns:
        stage_1_results = stage_1_results.set_index('custom_id')

    # Stage 2: Relation definition, using the output from Stage 1
    stage_2_file = os.path.join(output_path, "processed_results_2.json")
    if os.path.exists(stage_2_file) and not force_run_all:
        print("Stage 2 results already exist. Loading them...")
        # The file is already processed; load it without rewriting it.
        stage_2_results = pd.read_json(stage_2_file, lines=True)
    else:
        print("Stage 2 results do not exist or forced to run from scratch. Running Stage 2...")
        # Stage 2 must always consume Stage 1's output, so force_run is left at its
        # default here.
        stage_2_results = process_stage(disease_data, 2, previous_stage_results=stage_1_results)

#     # Stage 3 (currently disabled): Canonicalization, using the output from Stage 2
#     stage_3_file = os.path.join(output_path, "processed_results_3.json")
#     if os.path.exists(stage_3_file) and not force_run_all:
#         print("Stage 3 results already exist. Loading them...")
#         stage_3_results = pd.read_json(stage_3_file, lines=True)
#     else:
#         print("Stage 3 results do not exist or forced to run from scratch. Running Stage 3...")
#         stage_3_results = process_stage(disease_data, 3, previous_stage_results=stage_2_results, force_run=force_run_all)

# Run the pipeline only when this code is executed directly, not when it is imported.
if __name__ == "__main__":
    main(force_run_all=False)  # Set to True to force re-run all stages


Stage 1 results already exist. Loading them...
Stage 2 results do not exist or forced to run from scratch. Running Stage 2...
File uploaded successfully. File ID: 1740673288_7d16501679594b02a9ed2188014d9da2
Batch job created successfully. Batch ID: batch_1895147252189110272
Current status: validating
Current status: validating
Current status: validating
Current status: validating
Current status: validating
Current status: validating
Current status: validating
Current status: validating
Current status: validating
Current status: validating
Current status: validating
Current status: validating
Current status: validating
Current status: validating
Current status: validating
Current status: validating
Current status: validating
Current status: in_progress
Current status: in_progress
Current status: in_progress
Current status: completed
Batch processing completed. Downloading results.
Results downloaded and saved to KG/batch_output_stage_2.jsonl
