In [1]:
import json
import numpy as np
import tensorflow as tf
import keras
from keras import Input
from keras.models import Model
from keras.layers import Dense, Dropout
from keras.utils import to_categorical
import requests

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
Using TensorFlow backend.


In [2]:
# Load the classifier training set: parallel lists of sentences and their labels.
with open("../question_generate/clf_data.json","r",encoding="utf-8") as f:
    data = json.load(f)
train_data = data["train_data"]
train_labels = data["train_labels"]
# Sort the distinct labels so the label <-> id mapping is the same on every
# kernel restart; iterating a bare set of strings follows per-process hash
# randomization, which would shuffle the ids from run to run.
labels = sorted(set(train_labels))
label2id = {label:idx for idx,label in enumerate(labels)}
id2label = {idx:label for label,idx in label2id.items()}
# Integer class id for every training sentence, in the same order as train_data.
train_ids = [label2id[label] for label in train_labels]

In [3]:
from bert_serving.client import BertClient
# Client used by the cells below (bc.encode) to turn text into vectors.
# Created with default arguments — presumably requires a reachable
# bert-serving server; confirm against the deployment.
bc = BertClient()

In [4]:
# Encode every training sentence with the BERT client and one-hot encode the label ids.
train_vectors = bc.encode(train_data)
train_onehot_ids = to_categorical(train_ids)
# NOTE(review): 768 presumably matches the width of the vectors returned by
# bc.encode — confirm against the served BERT model.
embedding_dim = 768
dense_units = 128
dropout_rate = 0.5
# One softmax output per distinct label.
output_categories = len(label2id)

# Classifier head: Dense(relu) -> Dropout -> Dense(relu) -> Dense(softmax).
inputs = Input(shape=(embedding_dim,))
dense_in = Dense(dense_units,activation="relu")(inputs)
dropout = Dropout(dropout_rate)(dense_in)
dense_out = Dense(dense_units,activation="relu")(dropout)
outputs = Dense(output_categories,activation="softmax")(dense_out)
model = Model(inputs,outputs)
# Categorical cross-entropy pairs with the one-hot targets built above; accuracy is tracked as "acc".
model.compile(optimizer=keras.optimizers.Adam(lr=0.001,beta_1=0.9,beta_2=0.999,epsilon=1e-8),
              loss="categorical_crossentropy",
              metrics=["acc"])

Instructions for updating:
Colocations handled automatically by placer.


In [5]:
# Train the classifier on the encoded sentences for 15 epochs with mini-batches of 64.
# No validation data is passed, so only training metrics are reported.
model.fit(train_vectors,train_onehot_ids,batch_size=64,epochs=15)

Instructions for updating:
Use tf.cast instead.
Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15
Epoch 13/15
Epoch 14/15
Epoch 15/15


<keras.callbacks.callbacks.History at 0x7fa0ef56f438>

In [6]:
def get_P(text):
    """Return the predicted question category (label string) for ``text``.

    The text is encoded as a one-element batch with the BERT client, scored
    by the trained classifier, and the highest-scoring class id is mapped
    back to its label through ``id2label``.
    """
    sentence_vectors = bc.encode([text])
    class_scores = model.predict(sentence_vectors)
    best_id = class_scores.argmax()
    return id2label[best_id]

In [7]:
def get_S(text):
    """Ask the entity-extraction service for the entities mentioned in ``text``.

    Parameters
    ----------
    text : str
        The user's question.

    Returns
    -------
    list of str
        Entity names, obtained by splitting the ``'entities'`` field of the
        service's JSON reply on ``'|'``. Empty pieces are dropped, so a reply
        with no entities yields ``[]`` rather than ``['']``.
    """
    # Let requests build the query string: concatenating the raw text into the
    # URL breaks as soon as the question contains '&', '#', '+' or '%'.
    reply = requests.get("http://0.0.0.0:5788/", params={"s": text})
    response = json.loads(reply.text)
    return [entity for entity in response['entities'].split('|') if entity]

In [8]:
from py2neo import Graph,Node,NodeMatcher,RelationshipMatcher
class NeoGraph:
    """Thin wrapper around a py2neo ``Graph`` plus node/relationship matchers.

    Attributes are ``g`` (the graph connection), ``matcher`` (a
    ``NodeMatcher`` on it) and ``re_matcher`` (a ``RelationshipMatcher``).
    """

    def __init__(self, host="10.15.82.65", port=7687, user="neo4j", **graph_kwargs):
        """Open the graph connection and create the matchers.

        The defaults reproduce the previously hard-coded connection settings,
        so ``NeoGraph()`` behaves as before. Any extra keyword arguments
        (for example credentials) are forwarded unchanged to ``Graph``.
        """
        self.g = Graph(
            host=host,
            port=port,
            user=user,
            **graph_kwargs
        )
        self.matcher = NodeMatcher(self.g)
        self.re_matcher = RelationshipMatcher(self.g)

    def getNode(self, key, label=None):
        """Return the first node whose ``name`` property equals ``key``.

        When ``label`` is given the match is restricted to that label.
        The result is whatever ``.first()`` yields for the match; callers in
        this file treat ``None`` as "not found".
        """
        labels = () if label is None else (label,)
        return self.matcher.match(*labels, name=key).first()

# Shared graph handle used by the lookup functions in this notebook;
# constructing it runs NeoGraph.__init__ with its default connection settings.
handler = NeoGraph()

In [9]:
def sym2dis(syms):
    """Return the names of diseases related to *all* of the given symptoms.

    For every symptom name, the "症状" node is looked up and the start nodes
    of its incoming 'related_sym' relationships whose label string is exactly
    ":疾病" are collected; the per-symptom sets are then intersected.

    Returns an empty set when ``syms`` is empty or when any symptom has no
    node in the graph.
    """
    shared = None
    for sym in syms:
        node = handler.getNode(sym, "症状")
        if node is None:
            return set()
        diseases = {
            rel.start_node['name']
            for rel in handler.re_matcher.match(nodes=(None,node), r_type='related_sym')
            if str(rel.start_node.labels) == ":疾病"
        }
        # The first symptom seeds the result; later ones narrow it down.
        shared = diseases if shared is None else shared & diseases
    return shared if shared is not None else set()

# "Drug" here means a decoction piece (饮片), a medicinal material (药材) or a formula (方剂).
# Question form: "which drug treats all of these symptoms?" -> intersect the per-symptom results.
def sym2drug(syms):
    """Return the names of drugs that treat *all* of the given symptoms.

    For every symptom name, the "症状" node is looked up and the start nodes
    of its incoming 'treat' relationships whose label string is exactly
    ":饮片", ":药材" or ":方剂" are collected; the per-symptom sets are then
    intersected.

    Returns an empty set when ``syms`` is empty or when any symptom has no
    node in the graph.
    """
    drug_labels = (":饮片", ":药材", ":方剂")
    shared = None
    for sym in syms:
        node = handler.getNode(sym, "症状")
        if node is None:
            return set()
        treating = {
            rel.start_node['name']
            for rel in handler.re_matcher.match(nodes=(None,node), r_type='treat')
            if str(rel.start_node.labels) in drug_labels
        }
        # The first symptom seeds the result; later ones narrow it down.
        shared = treating if shared is None else shared & treating
    return shared if shared is not None else set()

# Question form: "what should I take for this disease?" -> union of the per-disease results.
def dis2drug(dises):
    """Return the names of drugs that treat *any* of the given diseases.

    For every disease name, the "疾病" node is looked up and the start nodes
    of its incoming 'treat' relationships whose label string is exactly
    ":饮片", ":药材" or ":方剂" are added to the result.

    Diseases with no node in the graph contribute nothing. They must be
    skipped explicitly: passing ``None`` as the end node to the relationship
    matcher means "any node", which would pull in every 'treat' relationship
    in the graph.
    """
    drug_labels = (":饮片", ":药材", ":方剂")
    res = set()
    for dis in dises:
        node = handler.getNode(dis, "疾病")
        if node is None:
            continue
        for i in handler.re_matcher.match(nodes=(None,node), r_type='treat'):
            if str(i.start_node.labels) in drug_labels:
                res.add(i.start_node['name'])
    return res

def drug2dis(drugs):
    """Return the names of conditions treated by *all* of the given drugs.

    Each drug name is resolved by trying the labels "药材", "方剂" and "饮片"
    in that order and keeping the first hit. The end nodes of its outgoing
    'treat' relationships whose label string is exactly ":疾病", ":症状" or
    ":症候" are collected, and the per-drug sets are intersected.

    Returns an empty set when ``drugs`` is empty or when any drug cannot be
    resolved to a node.
    """
    target_labels = (":疾病", ":症状", ":症候")
    shared = None
    for drug in drugs:
        node = None
        for drug_label in ("药材", "方剂", "饮片"):
            node = handler.getNode(drug, drug_label)
            if node is not None:
                break
        if node is None:
            return set()
        treated = {
            rel.end_node['name']
            for rel in handler.re_matcher.match(nodes=(node,None), r_type='treat')
            if str(rel.end_node.labels) in target_labels
        }
        # The first drug seeds the result; later ones narrow it down.
        shared = treated if shared is None else shared & treated
    return shared if shared is not None else set()

def ypcomp(yps):
    """Collect the attributes of two decoction pieces for side-by-side comparison.

    ``yps[0]`` and ``yps[1]`` are looked up as "饮片" nodes. For each node the
    end-node names of its outgoing 'function', 'property' and 'flavor'
    relationships are gathered under the keys '功效', '性' and '味'
    respectively; a key is only present when at least one relationship of
    that type exists.

    Returns a ``(dict, dict)`` tuple, one dict per input, or an empty set
    when either node is missing from the graph.
    """
    first = handler.getNode(yps[0], "饮片")
    second = handler.getNode(yps[1], "饮片")
    if first is None or second is None:
        return set()

    # (relationship type, output key), queried in this order for each node.
    relation_keys = (('function', '功效'), ('property', '性'), ('flavor', '味'))
    profiles = []
    for node in (first, second):
        profile = {}
        for r_type, key in relation_keys:
            for rel in handler.re_matcher.match(nodes=(node,None), r_type=r_type):
                profile.setdefault(key, []).append(rel.end_node['name'])
        profiles.append(profile)
    return profiles[0], profiles[1]

def drugcompo(drugs):
    """Return the ingredient names shared by two formulas.

    ``drugs[0]`` and ``drugs[1]`` are looked up as "方剂" nodes; for each, the
    start-node names of its incoming 'compose' relationships are collected,
    and the intersection of the two sets is returned.

    Returns an empty set when either formula is missing from the graph.
    """
    formula_a = handler.getNode(drugs[0], "方剂")
    formula_b = handler.getNode(drugs[1], "方剂")
    if formula_a is None or formula_b is None:
        return set()

    def ingredients(formula):
        # Names of every node that has a 'compose' relationship into `formula`.
        return {
            rel.start_node['name']
            for rel in handler.re_matcher.match(nodes=(None, formula), r_type='compose')
        }

    ingredients_a = ingredients(formula_a)
    ingredients_b = ingredients(formula_b)
    return ingredients_a & ingredients_b

In [10]:
# Dispatch table: predicted question category -> graph lookup that answers it.
function = {
    '症状到病': sym2dis,
    '症状到药': sym2drug,
    '疾病到药': dis2drug,
    '药到病': drug2dis,
    '饮片比较': ypcomp,
    '方剂组成': drugcompo,
}

def KBQA(text):
    """Answer a question end to end: classify it, extract entities, query the graph.

    The predicted category and extracted entities are printed, then the
    handler registered for that category in ``function`` is called with the
    entities.

    Returns the handler's result, or an empty set when the category has no
    handler or the handler raises (best effort; the error is printed).
    """
    clf = get_P(text)
    spo = get_S(text)
    print(clf,spo)
    try:
        res = function[clf](spo)
    except Exception as exc:
        # Stay best-effort, but do not trap KeyboardInterrupt/SystemExit the
        # way a bare ``except:`` does, and show why the lookup failed.
        print("KBQA lookup failed:", repr(exc))
        res = set()
    return res

In [11]:
# End-to-end demo: KBQA prints the predicted category and the extracted
# entities, and the cell displays the answer it returns.
text = "茜草汤和解毒汤有什么共同点"
KBQA(text)

方剂组成 ['茜草汤', '解毒汤']


{'皂角刺', '金银花'}

In [12]:
# Step-by-step demo of the symptom -> disease path: show the predicted
# category, the extracted entities, then the diseases shared by all symptoms.
# Note that get_S is called twice, so the entity service is queried twice.
text='我心悸,汗出偏沮，是得了什么病'
print(get_P(text))
print(get_S(text))
print(sym2dis(get_S(text)))

症状到病
['心悸', '汗出偏沮']
{'性感异常', '焦虑性神经症', '阴道后壁脱垂', '性厌恶'}


In [13]:
# Step-by-step demo of the symptom -> drug path: show the predicted category,
# the extracted entities, then the drugs that treat all of the symptoms.
# Note that get_S is called twice, so the entity service is queried twice.
text='我产后腹痛、骨痛而且小便不通，该吃什么'
print(get_P(text))
print(get_S(text))
print(sym2drug(get_S(text)))

症状到药
['产后腹痛', '骨痛', '小便不通']
{'芍药汤', '九节茶', '没药散'}


In [14]:
# Step-by-step demo of the disease -> drug path: show the predicted category,
# the extracted entities, then drugs that treat any of the diseases.
# Note that get_S is called twice, so the entity service is queried twice.
text='我得了淋病和痢疾，给我推荐点药'
print(get_P(text))
print(get_S(text))
print(list(dis2drug(get_S(text)))[:10]) # the union is too large, so only the first ten entries are shown

疾病到药
['淋病', '痢疾']
['万应灵膏', '昆明鸡脚黄连', '地磨薯', '石椒草', '细叶桉', '鄂西天胡荽', '普洱茶', '杜楝', '疳积饼', '镇国将军丸']


In [15]:
# Step-by-step demo of the drug -> condition path: show the predicted category,
# the extracted entities, then the conditions treated by all of the drugs.
# Note that get_S is called twice, so the entity service is queried twice.
text='草藤乌和牡蛎能治什么'
print(get_P(text))
print(get_S(text))
print(drug2dis(get_S(text)))

药到病
['草藤乌', '牡蛎']
{'头痛', '耳疮'}


In [16]:
# Step-by-step demo of the decoction-piece comparison: show the predicted
# category, the extracted entities, then the attribute dicts of both pieces.
# Note that get_S is called twice, so the entity service is queried twice.
text='川木通和九头草味道如何'
print(get_P(text))
print(get_S(text))
print(ypcomp(get_S(text)))

饮片比较
['川木通', '九头草']
({'功效': ['利水', '除烦', '清热', '通淋', '清心', '通经', '活血', '利尿', '通脉', '利便', '通乳', '下乳'], '性': ['寒'], '味': ['微苦']}, {'功效': ['止血', '利湿', '通经', '清热', '利尿', '活血', '破血', '调经'], '性': ['平'], '味': ['苦']})


In [None]:
# Step-by-step demo of the formula-composition path: show the predicted
# category, the extracted entities, then the ingredients shared by both formulas.
# Note that get_S is called twice, so the entity service is queried twice.
text='茜草汤和解毒汤有什么共同点'
print(get_P(text))
print(get_S(text))
print(drugcompo(get_S(text)))