In [134]:
from pprint import pprint
from paddlenlp import Taskflow

class SlotFiller:
    """Multi-turn slot filler for medical questions.

    Wraps two PaddleNLP Taskflow pipelines: ``information_extraction`` (UIE)
    for entities of type 药物 / 症状 / 疾病, and ``zero_shot_text_classification``
    (utc-base) for the question intent. The slot dict lives on the instance and
    is updated in place by every ``get_slot`` call, so entities found in an
    earlier turn stay in the slot until a newer entity of the same type (or
    ``reset``) replaces them.
    """

    def __init__(self):
        self.slot = {
            "intent": None,
            "intent_score": None,
            # entity type -> {entity text: probability-set}
            "entitys": {},
        }
        self.entity_schema = ['药物', '症状', '疾病']
        self.intent_schema = ["疾病的防御措施", "疾病产生的原因", "患病该吃什么药", "疾病的症状", "药物能治疗什么病", "其他"]
        # Both pipelines are created here; this is where model/tokenizer loading happens.
        self.entity_extract = Taskflow('information_extraction', schema=self.entity_schema)
        self.intent_cls = Taskflow("zero_shot_text_classification", model="utc-base", schema=self.intent_schema)

    def reset(self):
        """Clear intent and entities in place so the next turn starts fresh.

        The existing slot dict is mutated rather than replaced, so a reference
        previously returned by ``get_slot`` observes the cleared state too.
        """
        self.slot["intent"] = None
        self.slot["intent_score"] = None
        self.slot["entitys"].clear()

    def get_slot(self, text):
        """Update the slot with the intent and entities extracted from ``text``.

        Args:
            text: the user's utterance.

        Returns:
            ``self.slot`` (the same dict object on every call) with keys
            "intent", "intent_score" and "entitys". "entitys" maps an entity
            type to ``{entity_text: {probability}}`` for the first mention the
            extractor returned for that type.
        """
        entitys = self.entity_extract(text)
        intent = self.intent_cls(text)

        # Only element 0 is read: one string is passed in, so presumably one
        # result comes back. Empty lists/predictions leave the slot untouched.
        if intent and intent[0].get("predictions"):
            top = intent[0]["predictions"][0]
            # Always refresh the score, even when the label is unchanged, so it
            # reflects the current utterance instead of an earlier turn.
            self.slot["intent"] = top["label"]
            self.slot["intent_score"] = top["score"]

        if entitys:
            for entity_type, mentions in entitys[0].items():
                if not mentions:
                    continue
                first = mentions[0]
                # NOTE(review): the probability is wrapped in a one-element set,
                # kept only for compatibility with existing output; a plain
                # float was probably intended.
                self.slot["entitys"][entity_type] = {first["text"]: {first["probability"]}}
        return self.slot



class SqlParser:
    """Translate a filled slot into a Neo4j Cypher query (despite the name, not SQL)."""

    # intent -> (entity type the query needs, Cypher template with one '{}' placeholder)
    _TEMPLATES = {
        "疾病的症状": ('疾病', "MATCH (m:Disease)-[r:has_symptom]->(n:Symptom) where m.name = '{}' return m.name, n.name"),
        "疾病的防御措施": ('疾病', "MATCH (m:Disease) where m.name = '{}' return m.name, m.prevent"),
        "疾病产生的原因": ('疾病', "MATCH (m:Disease) where m.name = '{}' return m.name, m.cause"),
        "患病该吃什么药": ('疾病', "MATCH (m:Disease)-[r:common_drug]->(n:Drug) where m.name = '{}' return m.name, n.name"),
        "药物能治疗什么病": ('药物', "MATCH (m:Disease)-[r:common_drug]->(n:Drug) where n.name = '{}' return m.name, r.name, n.name"),
    }

    @staticmethod
    def _quote(value):
        """Escape ``value`` for use inside a single-quoted Cypher string literal."""
        # Backslash first, otherwise the escapes added for quotes get doubled.
        return str(value).replace("\\", "\\\\").replace("'", "\\'")

    def parse(self, slot):
        """Build the Cypher query for ``slot``.

        Args:
            slot: dict with "intent" and "entitys" keys, as produced by
                ``SlotFiller.get_slot``.

        Returns:
            The query string, or None when the intent is "其他"/unknown or the
            entity type that intent needs is absent or empty.
        """
        rule = self._TEMPLATES.get(slot.get("intent"))
        if rule is None:
            return None
        entity_type, template = rule
        # Read from the ``slot`` argument (the old code read a notebook global).
        candidates = slot.get("entitys", {}).get(entity_type)
        if not candidates:
            return None
        entity = next(iter(candidates))
        # The entity text originates from user input, so escape it before
        # splicing it into the query string.
        return template.format(self._quote(entity))

In [135]:
# Build the slot filler once: SlotFiller.__init__ creates both Taskflow
# pipelines (entity extraction and intent classification), which produces the
# tokenizer-loading log in this cell's output.
slotf = SlotFiller()

[32m[2023-11-06 15:25:24,078] [    INFO][0m - We are using <class 'paddlenlp.transformers.ernie.tokenizer.ErnieTokenizer'> to load 'C:\Users\131655\.paddlenlp\taskflow\information_extraction\uie-base'.[0m
[32m[2023-11-06 15:25:25,101] [    INFO][0m - We are using <class 'paddlenlp.transformers.ernie.tokenizer.ErnieTokenizer'> to load 'C:\Users\131655\.paddlenlp\taskflow\zero_shot_text_classification\utc-base'.[0m
[32m[2023-11-06 15:25:25,135] [    INFO][0m - Assigning ['[O-MASK]'] to the additional_special_tokens key of the tokenizer[0m


In [148]:
# Example: disease + "which medicine" question.
# NOTE: slotf keeps its slot between calls, so the printed entities can include
# ones from previously executed cells (e.g. the 药物 entry in this output).
sqlp = SqlParser()
slotins = slotf.get_slot("我高血压，该吃啥药")
print(slotins)
sql = sqlp.parse(slotins)
print(sql)

{'intent': '患病该吃什么药', 'intent_score': 0.9446698357853085, 'entitys': {'疾病': {'高血压': {0.8261676259649207}}, '药物': {'厄贝沙坦片': {0.969891415250288}}}}
MATCH (m:Disease)-[r:common_drug]->(n:Drug) where m.name = '高血压' return m.name, n.name


In [137]:
# Example: a short follow-up that names only the disease.
sqlp = SqlParser()
slotins = slotf.get_slot("高血压呢")
print(slotins)
sql = sqlp.parse(slotins)
print(sql)

{'intent': '疾病的症状', 'intent_score': 0.8013479716695394, 'entitys': {'疾病': {'高血压': {0.7640298965737031}}}}
MATCH (m:Disease)-[r:has_symptom]->(n:Symptom) where m.name = '高血压' return m.name, n.name


In [139]:
# Example: which medicine treats a given disease.
sqlp = SqlParser()
slotins = slotf.get_slot("急性胃炎吃什么药")
print(slotins)
sql = sqlp.parse(slotins)
print(sql)

{'intent': '患病该吃什么药', 'intent_score': 0.9286933121028708, 'entitys': {'疾病': {'急性胃炎': {0.9604805477728178}}}}
MATCH (m:Disease)-[r:common_drug]->(n:Drug) where m.name = '急性胃炎' return m.name, n.name


In [161]:
# Example: which diseases a drug treats. The 疾病 entry in this output is left
# over from an earlier call, because slotf accumulates entities across calls.
sqlp = SqlParser()
slotins = slotf.get_slot("富马酸喹硫平片可以治疗什么病")
print(slotins)
sql = sqlp.parse(slotins)
print(sql)

{'intent': '药物能治疗什么病', 'intent_score': 0.8591247398489691, 'entitys': {'疾病': {'高血压': {0.8261676259649207}}, '药物': {'富马酸喹硫平片': {0.9567390025445448}}}}
MATCH (m:Disease)-[r:common_drug]->(n:Drug) where n.name = '富马酸喹硫平片' return m.name, r.name, n.name


In [162]:
import getpass
import os

from py2neo import Graph, Node

# Connection settings come from the environment so the password is not stored
# in the notebook; if NEO4J_PASSWORD is unset, prompt for it interactively.
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD") or getpass.getpass("Neo4j password: ")
g = Graph(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))

In [163]:
# Execute the generated Cypher against Neo4j; the returned cursor is the
# cell's last expression, so it is displayed as a result table.
cypher_result = g.run(sql)
cypher_result

m.name,r.name,n.name
旅途精神病,常用药品,富马酸喹硫平片
露阴癖,常用药品,富马酸喹硫平片
急性应激反应,常用药品,富马酸喹硫平片


In [164]:
# Display the Cypher query string that was passed to g.run(...) above.
sql

"MATCH (m:Disease)-[r:common_drug]->(n:Drug) where n.name = '富马酸喹硫平片' return m.name, r.name, n.name"