In [9]:
import requests
import json
from json.decoder import JSONDecodeError
import os
from dotenv import load_dotenv

# load_dotenv comes from python-dotenv; it is expected to populate the process
# environment from a local .env file before the lookups below run.
load_dotenv()

# API settings are read from the environment. os.getenv returns None for any
# variable that is not set, so these may be None at call time.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
OPENAI_MODEL = os.getenv("OPENAI_MODEL")
OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL")


def parse_json_markdown(md_string):
    """Parse the brace-delimited object embedded in a free-text / markdown reply.

    The candidate payload is everything from the first ``{`` to the last ``}``
    in ``md_string`` (so surrounding prose or code fences are ignored). It is
    parsed as JSON first; if that fails, it is parsed as a Python literal so
    that replies using single quotes / ``True`` / ``None`` are still accepted.

    Args:
        md_string: Text that is expected to contain one JSON-like object.

    Returns:
        The parsed Python object (normally a ``dict``).

    Raises:
        ValueError: If no ``{...}`` span exists, or the span is not a valid
            literal.
        SyntaxError: If the span is neither valid JSON nor valid Python syntax.
    """
    import ast  # local import: only needed for the literal fallback

    start = md_string.find('{')
    # rfind returns -1 when there is no closing brace, which makes `end` 0.
    end = md_string.rfind('}') + 1
    if start == -1 or end <= start:
        raise ValueError("no JSON object found in input")

    json_string = md_string[start:end]

    try:
        return json.loads(json_string)
    except JSONDecodeError:
        # SECURITY: the text comes from a model reply, so it must never be
        # passed to eval(). literal_eval only accepts literal syntax and
        # rejects calls, attribute access and other expressions.
        return ast.literal_eval(json_string)

def load_schema(data_path):
    """Load one schema record per line from ``data_path``.

    Each non-blank line is parsed as JSON; a line that is not valid JSON is
    parsed as a Python literal instead (e.g. a dict written with single
    quotes). Blank lines are skipped.

    Args:
        data_path: Path to a UTF-8 text file with one record per line.

    Returns:
        A list with one parsed object per non-blank line, in file order.

    Raises:
        ValueError / SyntaxError: If a line is neither JSON nor a Python literal.
        OSError: If the file cannot be opened.
    """
    import ast  # local import: only needed for the literal fallback

    schema_list = []
    # Explicit UTF-8 so non-ASCII schema names decode the same on every
    # platform, matching how load_data opens its files.
    with open(data_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                schema_list.append(json.loads(line))
            except JSONDecodeError:
                # SECURITY: file content is data, not code; literal_eval
                # rejects anything that is not a plain literal.
                schema_list.append(ast.literal_eval(line))

    return schema_list

def load_data(data_path):
    """Load a JSON-lines file into a list of parsed records.

    Args:
        data_path: Path to a UTF-8 text file with one JSON document per line.

    Returns:
        A list with one parsed object per non-blank line, in file order.
        An empty (or whitespace-only) file yields an empty list.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        OSError: If the file cannot be opened.
    """
    with open(data_path, 'r', encoding='utf-8') as file:
        # Iterate line by line instead of reading the whole file, and skip
        # blank lines so an empty file or a stray empty line does not raise.
        return [json.loads(line) for line in file if line.strip()]

def process_raw_data(raw_data_list):
    """Flatten raw rows into ``{"spo": {...}, "text": ...}`` records.

    Only the first entry of each row's ``spo_list`` is used; any further
    entries are ignored. A row with an empty ``spo_list`` raises IndexError and
    a row missing a required key raises KeyError.

    Args:
        raw_data_list: Iterable of dicts with a ``text`` key and a non-empty
            ``spo_list`` whose items carry subject/object/predicate fields.

    Returns:
        A list of records, one per input row, in input order.
    """
    # Field order is kept fixed because the spo dict is later rendered with str().
    spo_fields = ('subject', 'subject_type', 'object', 'object_type', 'predicate')

    def _to_record(row):
        # Build one flattened record from the row's text and its first triple.
        text = row['text']
        first_spo = row["spo_list"][0]
        return {
            "spo": {field: first_spo[field] for field in spo_fields},
            "text": text,
        }

    return [_to_record(row) for row in raw_data_list]


# Load the schema definitions and the raw train/validation rows.
# Paths are relative to the notebook's working directory.
schema_list = load_schema("data/53_schemas.json")
train_raw_list = load_data("data/train_data.json")
val_raw_list = load_data("data/val_data.json")


In [10]:
schema_list

[{'subject_type': '疾病', 'predicate': '预防', 'object_type': '其他'},
 {'subject_type': '疾病', 'predicate': '阶段', 'object_type': '其他'},
 {'subject_type': '疾病', 'predicate': '就诊科室', 'object_type': '其他'},
 {'subject_type': '其他', 'predicate': '同义词（其他/其他）', 'object_type': '其他'},
 {'subject_type': '疾病', 'predicate': '辅助治疗', 'object_type': '其他治疗'},
 {'subject_type': '疾病', 'predicate': '化疗', 'object_type': '其他治疗'},
 {'subject_type': '疾病', 'predicate': '放射治疗', 'object_type': '其他治疗'},
 {'subject_type': '其他治疗',
  'predicate': '同义词（其他治疗/其他治疗）',
  'object_type': '其他治疗'},
 {'subject_type': '疾病', 'predicate': '手术治疗', 'object_type': '手术治疗'},
 {'subject_type': '手术治疗',
  'predicate': '同义词（手术治疗/手术治疗）',
  'object_type': '手术治疗'},
 {'subject_type': '疾病', 'predicate': '实验室检查', 'object_type': '检查'},
 {'subject_type': '疾病', 'predicate': '影像学检查', 'object_type': '检查'},
 {'subject_type': '疾病', 'predicate': '辅助检查', 'object_type': '检查'},
 {'subject_type': '疾病', 'predicate': '组织学检查', 'object_type': '检查'},
 {'subject_type

In [11]:
train_raw_list[0]

{'text': '产后抑郁症@区分产后抑郁症与轻度情绪失调（产后忧郁或“婴儿忧郁”）是重要的，因为轻度情绪失调不需要治疗。',
 'spo_list': [{'Combined': False,
   'object': '轻度情绪失调',
   'object_type': '疾病',
   'predicate': '鉴别诊断',
   'subject': '产后抑郁症',
   'subject_type': '疾病'}]}

In [12]:
# Flatten both splits; process_raw_data keeps only the first spo_list entry of each row.
train_list, val_list = process_raw_data(train_raw_list), process_raw_data(val_raw_list)

# Check DeepSeek balance

In [13]:
# url = "https://api.deepseek.com/user/balance"

# payload={}
# headers = {
#   'Accept': 'application/json',
#   'Authorization': f'Bearer {API_KEY}'
# }

# response = requests.request("GET", url, headers=headers, data=payload)

# print(response.text)

In [14]:
from openai import OpenAI

def create_chat_completions(model, system_message, user_message, max_tokens=1024, temperature=0.7, base_url=None):
    """Send one system + user message pair and return the reply text.

    A new ``OpenAI`` client is constructed on every call, using the
    module-level ``OPENAI_API_KEY`` and the given ``base_url``.

    Args:
        model: Model name passed through to the API.
        system_message: Content of the ``system`` message.
        user_message: Content of the ``user`` message.
        max_tokens: Passed through as ``max_tokens``.
        temperature: Passed through as ``temperature``.
        base_url: Passed through to the ``OpenAI`` constructor.

    Returns:
        ``message.content`` of the first choice in the response.
    """
    conversation = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]
    request_options = {
        "model": model,
        "messages": conversation,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "stream": False,
    }

    api = OpenAI(api_key=OPENAI_API_KEY, base_url=base_url)
    completion = api.chat.completions.create(**request_options)

    first_choice = completion.choices[0]
    return first_choice.message.content

# Smoke test: one round trip using the model and base URL read from the environment.
result = create_chat_completions(model=OPENAI_MODEL, system_message="You are a helpful assistant", user_message="Hello", base_url=OPENAI_BASE_URL)
print(result)


Hello! How can I assist you today?


In [7]:
import random
from collections import namedtuple

def create_prompt(train_list, val_list, schema_list, sample_size=10):
    """Yield one few-shot prompt per validation row.

    Each prompt's system message contains a fixed instruction, every entry of
    ``schema_list`` (one per line), and ``sample_size`` training examples
    rendered as ``"<text> ==> <spo dict>"``. Examples are drawn without
    replacement, so the same training row never appears twice in one prompt;
    if ``train_list`` has fewer than ``sample_size`` rows, all of them are used.

    Args:
        train_list: Sequence of records with ``text`` and ``spo`` keys, used as
            the pool of few-shot examples.
        val_list: Iterable of records with ``text`` and ``spo`` keys; one
            prompt is produced for each.
        schema_list: Iterable of schema entries, rendered with ``str()``.
        sample_size: Maximum number of few-shot examples per prompt.

    Yields:
        ``Template(system_message, user_message, label)`` namedtuples, where
        ``label`` is the validation row's ``spo`` value.
    """
    Template = namedtuple("Template", ["system_message", "user_message", "label"])

    # The instruction and schema listing are identical for every row, so they
    # are built once instead of once per validation row.
    instruction = "你是专门进行实体抽取的专家。请在 schema 中定义的范畴，参考范例中的格式，从 user 给定句子中抽取出符合 schema 定义的实体，不存在的实体类型返回空列表。请按照JSON字符串的格式回答。\n\n"
    schema_block = "Schema: \n" + "".join(str(schema) + "\n" for schema in schema_list)
    prompt_head = instruction + schema_block + "\nSamples: \n"

    # random.sample raises if k exceeds the population size (or is negative),
    # so clamp to the number of available training rows.
    k = max(0, min(sample_size, len(train_list)))

    for row in val_list:
        few_shot = random.sample(train_list, k)
        samples_block = "".join(
            f"{record['text']} ==> {str(record['spo'])}" + "\n" for record in few_shot
        )

        yield Template(
            system_message=prompt_head + samples_block,
            user_message=row['text'] + " ==> ",
            label=row['spo'])



In [None]:
# from tqdm.auto import tqdm

# prompts = create_prompt(train_list, val_list, schema_list, sample_size=20)

# error_case_list = []
# result_list = []

# for template in tqdm(prompts, total=len(val_list)):
#     system_message, user_message, label = template.system_message, template.user_message, template.label

#     result = create_chat_completions(model="deepseek-chat", system_message=system_message, user_message=user_message, base_url="https://api.deepseek.com")

#     try:
#         result_parsed = parse_json_markdown(result)
#     except Exception as e:
#         error_case_list.append((result, label))
#         continue
#     else:
#         result_list.append((result_parsed, label))

#     if len(result_list) >= 3:
#         break

In [8]:
from tqdm.auto import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

# create_prompt is a generator: prompts are built lazily and can be consumed only once.
prompts = create_prompt(train_list, val_list, schema_list, sample_size=20)

# Accumulators filled by the thread-pool loop below; re-running this cell resets them.
error_case_list = []
result_list = []

def process_template(template):
    """Run one prompt through the chat API and parse the reply.

    Args:
        template: Object with ``system_message``, ``user_message`` and
            ``label`` attributes (as yielded by ``create_prompt``).

    Returns:
        ``(parsed_reply, label)`` on success, or
        ``('error', (raw_reply_or_error_text, label))`` on failure. When the
        API call itself fails there is no raw reply, so the second form carries
        ``"<ExceptionType>: <message>"`` in its place.
    """
    system_message, user_message, label = template.system_message, template.user_message, template.label

    # Bound before the try so the except branch can always reference it; without
    # this, a failing API call raised UnboundLocalError and hid the real error.
    result = None
    try:
        result = create_chat_completions(model="deepseek-chat", system_message=system_message, user_message=user_message, base_url="https://api.deepseek.com")
        result_parsed = parse_json_markdown(result)
        return (result_parsed, label)
    except Exception as e:
        if result is None:
            # No reply text is available; keep the failure reason instead.
            result = f"{type(e).__name__}: {e}"
        return ('error', (result, label))

# Define the number of workers for concurrency
num_workers = 5

with ThreadPoolExecutor(max_workers=num_workers) as executor:
    # The dict comprehension drains the `prompts` generator and submits every
    # template up front; the mapping lets a finished future be traced back to
    # the template it was created from.
    future_to_template = {executor.submit(process_template, template): template for template in prompts}
    
    # as_completed yields futures in completion order, not submission order,
    # so result_list / error_case_list are not aligned with val_list order.
    for future in tqdm(as_completed(future_to_template), total=len(future_to_template)):
        # Looked up here but not used further in this cell.
        template = future_to_template[future]
        try:
            result = future.result()
            # process_template signals failure with ('error', (raw, label)).
            if result[0] == 'error':
                error_case_list.append(result[1])
            else:
                result_list.append(result)
        except Exception as e:
            # Any exception that escaped process_template is recorded as text only.
            error_case_list.append(('Exception', str(e)))


# # Printing the results (optional, for debugging)
# print("Results:", result_list)
# print("Errors:", error_case_list)

  0%|          | 0/3585 [00:00<?, ?it/s]

In [9]:
# Share of validation rows whose request or reply parsing ended in error_case_list.
# NOTE(review): relies on error_case_list having been filled by the thread-pool
# cell in the same kernel session; raises ZeroDivisionError if val_list is empty.
error_rate = len(error_case_list) / len(val_list)


In [10]:
print(error_rate)

0.3788005578800558


In [15]:
import pickle

# File path where the pickle file will be saved
file_path = 'data/val_result.pkl'

# Open the file in write-binary mode and save result_list, whose items are the
# (parsed_reply, label) tuples collected by the thread-pool loop.
with open(file_path, 'wb') as file:
    pickle.dump(result_list, file)

print(f"List of dictionaries has been saved to {file_path}")

List of dictionaries has been saved to data/val_result.pkl


In [16]:
# Open the file in read-binary mode and load the saved results back.
# `file_path` is the variable defined in the previous cell.
# SECURITY: pickle.load can execute arbitrary code; only load files this
# notebook wrote itself, never pickles from an untrusted source.
with open(file_path, 'rb') as file:
    loaded_list_of_dicts = pickle.load(file)
