In [10]:
import re
import jieba
import pickle
from sklearn.feature_extraction.text import CountVectorizer
import joblib


# 分词
def tokenize(text):
    """
    Clean a piece of text and split it into tokens with jieba.

    Cleaning steps, applied in this order:
      1. ``{%...%}`` spans (geo tags, topics, ...) are replaced by a space.
      2. ``@name`` mentions, up to the next space or the end of the text,
         are replaced by a space.
      3. ``【...】`` spans are replaced by a space (their content is usually
         not written by the user).
      4. ``[...]`` emoticon markers are collected and each one is replaced
         by the placeholder ``IconMark`` before segmentation.

    After segmentation every placeholder occurrence is swapped back for the
    next collected marker, and every other token is kept only if it is
    purely alphabetic according to ``str.isalpha``.

    :param text: raw input string.
    :return: list of tokens (words and restored ``[...]`` markers).
    """
    # Raw strings: the patterns contain backslash escapes that are regex
    # syntax, not Python string escapes.
    text = re.sub(r"\{%.+?%\}", " ", text)
    text = re.sub(r"@.+?( |$)", " ", text)
    text = re.sub(r"【.+?】", " ", text)
    icons = re.findall(r"\[.+?\]", text)
    text = re.sub(r"\[.+?\]", "IconMark", text)

    # Consume the collected markers front to back without list.pop(0).
    pending_icons = iter(icons)

    tokens = []
    for w in jieba.lcut(text):
        w = w.strip()
        if "IconMark" in w:
            # One token may hold several placeholders (adjacent markers).
            for _ in range(w.count("IconMark")):
                icon = next(pending_icons, None)
                if icon is None:
                    # The input itself contained the text "IconMark", so
                    # there are more placeholders than markers: stop
                    # restoring instead of raising.
                    break
                tokens.append(icon)
        elif w.isalpha():
            # isalpha() is already False for '', for '\u200b' and for any
            # token containing digits, punctuation or whitespace.
            tokens.append(w)
    return tokens


pp = './data'
# Load the model.
model_path0 = pp+'/bayes_01.02.pkl'
model0 = joblib.load(model_path0)
# Load the vocabulary and build a CountVectorizer restricted to it.
vec_path0 = pp+'/feature_01.02.pkl'
# NOTE(review): unpickling can execute arbitrary code; only load files that
# come from a trusted source.
# The context manager closes the file handle once the vocabulary is read.
with open(vec_path0, "rb") as vocab_file:
    vec0 = CountVectorizer(decode_error="replace", vocabulary=pickle.load(vocab_file))


# 处理数据
def getType(string_l, threshold=0.75):
    """
    Classify one text, or a list of texts, into an industry category.

    Each text is tokenized with ``tokenize``, vectorized with the module
    level ``vec0`` and scored with the module level ``model0``.

    :param string_l: a single string or a list of strings.
    :param threshold: minimum probability the predicted category must reach
        to be reported; below it the text is labelled ``'其他'`` with a
        score of ``-1``. Defaults to 0.75.
    :return: a list with one entry per input text, each entry being
        ``[label, score, details]`` where ``details`` is a list of
        ``'category:probability'`` strings for every category.
        An empty list yields ``[]``. Any input that is neither a ``str``
        nor a ``list`` prints ``'error'`` and yields ``None``.
    """
    type_ = ['ICT', '新能源汽车', '生物医药', '医疗器械', '钢铁',
             '能源', '工业机器人', '先进轨道交通', '数控机床', '工业软件',
             '高端装备', '半导体', '人工智能', '稀土']

    if isinstance(string_l, str):
        string_l = [string_l]
    if not isinstance(string_l, list):
        print('error')
        return None
    if not string_l:
        # Nothing to classify: do not hand zero samples to the model.
        return []

    X_data = [" ".join(tokenize(string)) for string in string_l]

    vecc = vec0.transform(X_data)
    result_pre = model0.predict(vecc)
    result_pre_proba = model0.predict_proba(vecc)

    res = []
    # NOTE(review): the predicted label is used directly as an index into
    # both the probability row and `type_`; this assumes the model's labels
    # are the integers 0..len(type_)-1 in that order -- TODO confirm.
    for end_index, proba_row in zip(result_pre, result_pre_proba):
        end_score = proba_row[end_index]
        if end_score < threshold:
            sin_res = ['其他', -1]
        else:
            sin_res = [type_[end_index], end_score]
        sin_res.append([name + ':' + str(proba_row[j]) for j, name in enumerate(type_)])
        res.append(sin_res)
    return res


if __name__ == '__main__':
    # Demo run: classify the same headline twice and print the full result.
    string = ["2018沃德十佳发动机公布，国内可以买到哪些车型？"] * 2
    result = getType(string)
    print(result)
    # Print only the predicted label (first element) of every entry.
    all_res = [entry[0] for entry in result]
    print(all_res)



[['新能源汽车', 0.9867690160541083, ['ICT:0.002260762034060226', '新能源汽车:0.9867690160541083', '生物医药:0.000969709968777577', '医疗器械:0.00036090812561861805', '钢铁:8.068983700896634e-05', '能源:0.00827086053231211', '工业机器人:0.000389800014411476', '先进轨道交通:1.3028017159287019e-05', '数控机床:0.00016944309633607987', '工业软件:0.0002301898648119734', '高端装备:3.660549376338843e-05', '半导体:6.920364388501734e-06', '人工智能:2.739467241641826e-05', '稀土:0.00041467192482425264']], ['新能源汽车', 0.9867690160541083, ['ICT:0.002260762034060226', '新能源汽车:0.9867690160541083', '生物医药:0.000969709968777577', '医疗器械:0.00036090812561861805', '钢铁:8.068983700896634e-05', '能源:0.00827086053231211', '工业机器人:0.000389800014411476', '先进轨道交通:1.3028017159287019e-05', '数控机床:0.00016944309633607987', '工业软件:0.0002301898648119734', '高端装备:3.660549376338843e-05', '半导体:6.920364388501734e-06', '人工智能:2.739467241641826e-05', '稀土:0.00041467192482425264']]]
['新能源汽车', '新能源汽车']
