In [16]:
import re
import jieba
import pickle
from sklearn.feature_extraction.text import CountVectorizer
import joblib


# 分词
def tokenize(text):
    """
    Clean a Weibo-style text and segment it into tokens with jieba.

    Cleaning steps, applied in this order:

    1. ``{%...%}`` spans (geo-location, Weibo topics, etc.) become a space.
    2. ``@name`` mentions, up to the next space or the end of the text,
       become a space.
    3. ``【...】`` spans become a space (their content is usually not
       written by the user).
    4. ``[...]`` emoticon markers are collected and replaced with the
       placeholder ``IconMark`` so the segmenter does not split them.

    After segmentation every placeholder is swapped back for the emoticon
    it stood for, in order of appearance. Any other token is kept only if
    ``str.isalpha()`` is true for it, which drops punctuation, digits,
    empty strings and zero-width characters.

    Args:
        text: Raw input string.

    Returns:
        A list of token strings (words and ``[...]`` emoticon markers).
    """
    # Raw strings keep the regex escapes (\{, \[, ...) from being treated as
    # invalid Python string escapes.
    text = re.sub(r"\{%.+?%\}", " ", text)  # drop {%xxx%}
    text = re.sub(r"@.+?( |$)", " ", text)  # drop @xxx (user names)
    text = re.sub(r"【.+?】", " ", text)  # drop 【xx】
    icons = re.findall(r"\[.+?\]", text)  # collect every emoticon marker
    text = re.sub(r"\[.+?\]", "IconMark", text)  # protect them from the segmenter

    # An iterator hands the icons back in order. Unlike list.pop(0), it cannot
    # raise when the input itself contained the literal text "IconMark" and
    # there are therefore more placeholders than collected icons.
    remaining_icons = iter(icons)

    tokens = []
    for w in jieba.lcut(text):
        w = w.strip()
        if "IconMark" in w:  # restore the original emoticon marker(s)
            for _ in range(w.count("IconMark")):
                icon = next(remaining_icons, None)
                if icon is not None:
                    tokens.append(icon)
        elif w.isalpha():  # keep alphabetic tokens only; '' and '\u200b' fail this
            tokens.append(w)
    return tokens


# Directory that holds the trained artifacts.
pp = '../result5'
# Load the trained classifier.
# NOTE(security): joblib.load and pickle.load can execute arbitrary code
# embedded in the file; only load artifacts from a trusted source.
model_path0 = pp + '/bayes_01.02.pkl'
model0 = joblib.load(model_path0)
# Load the pickled vocabulary and build a vectorizer fixed to it.
vec_path0 = pp + '/feature_01.02.pkl'
# The context manager closes the file handle once the vocabulary is read.
with open(vec_path0, "rb") as _vocab_file:
    vec0 = CountVectorizer(decode_error="replace", vocabulary=pickle.load(_vocab_file))


# 处理数据
def getType(string_l):
    """
    Classify one or more texts into an industry category.

    Each text is tokenized with ``tokenize``, vectorized with the
    module-level ``vec0`` and scored with the module-level ``model0``.

    Args:
        string_l: A single string or a list of strings to classify.

    Returns:
        A list with one entry per input text. Each entry is a list
        ``[label, score, details]``:

        * ``label`` is the predicted category name, or ``'其他'`` when the
          predicted class's probability is below 0.75.
        * ``score`` is that probability, or ``-1`` when below the threshold.
        * ``details`` is a list of ``'<category>:<probability>'`` strings,
          one per category in ``type_`` order.

        An empty list is returned for an empty input list. If ``string_l``
        is neither a ``str`` nor a ``list``, ``'error'`` is printed and
        ``None`` is returned.
    """
    type_=['ICT', '新能源汽车', '生物医药', '医疗器械', '钢铁', 
           '能源', '工业机器人', '先进轨道交通', '半导体', '高端设备', 
           '工业软件', '人工智能', '数控机床', '稀土']

    if isinstance(string_l, str):
        string_l = [string_l]
    if not isinstance(string_l, list):
        # Kept as a best-effort report rather than an exception so existing
        # callers that rely on the None result keep working.
        print('error')
        return None
    if not string_l:
        # Nothing to classify; skip the model call, which is expected to
        # reject an input with zero samples.
        return []

    # The vectorizer expects whitespace-separated tokens.
    X_data = [" ".join(tokenize(string)) for string in string_l]

    vecc = vec0.transform(X_data)
    result_pre = model0.predict(vecc)
    result_pre_proba = model0.predict_proba(vecc)

    # NOTE(review): the predicted label is used directly as a column index
    # into the probability row and as an index into type_. That assumes the
    # model's classes are exactly 0..len(type_)-1 in this order -- confirm
    # against the training code.
    res = []
    for end_index, proba_row in zip(result_pre, result_pre_proba):
        end_score = proba_row[end_index]
        details = [name + ':' + str(proba_row[j]) for j, name in enumerate(type_)]
        if end_score < 0.75:
            # Low confidence: report the fallback label with a sentinel score.
            res.append(['其他', -1, details])
        else:
            res.append([type_[end_index], end_score, details])
    return res


if __name__ == '__main__':
    # Smoke test: classify one sample headline, print the full result and
    # then only the predicted labels.
    string = ["芯云半导体高端集成电路测试基地结顶，致力于打造世界一流集成电路测试服务基地"]
    result = getType(string)
    print(result)
    # Every entry is [label, score, details]; keep the label only.
    all_res = [label for label, _score, _details in result]
    print(all_res)



[['半导体', 0.9999970464898087, ['ICT:8.36477289793891e-07', '新能源汽车:7.363009661237748e-07', '生物医药:2.3720168900728588e-15', '医疗器械:6.209759798358118e-14', '钢铁:3.63451191304875e-11', '能源:4.5962322164625324e-10', '工业机器人:1.2353308196857472e-06', '先进轨道交通:3.3459642999070765e-13', '半导体:0.9999970464898087', '高端设备:2.531032719759328e-13', '工业软件:2.852759090800844e-12', '人工智能:2.8154402188767964e-12', '数控机床:2.2173056050835066e-13', '稀土:1.448986030607285e-07']]]
['半导体']
