# In [20]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
from transformers import BertTokenizer, BertModel
import pandas as pd
import re
import jieba

# In [14]:
# Relative path to a BERT checkpoint directory.
# NOTE(review): nothing in this part of the file reads it; `Model` loads its
# tokenizer/weights from its own path.
BERT_PATH = "bert-base-chinese/"

# In [25]:
class Model(nn.Module):
    """Encode pre-processed sentences with a BERT model loaded from disk.

    ``forward`` tokenizes a batch of sentences and returns BERT's last hidden
    state (``outputs[0]``).
    """

    def __init__(self, bert_path="C:/Users/12968/Desktop/chinese-bert-wwm-ext"):
        """Load the tokenizer and the encoder.

        :param bert_path: checkpoint directory (or hub id) passed to
            ``from_pretrained`` for both tokenizer and model. Defaults to the
            path that used to be hard-coded, so ``Model()`` behaves as before.
        """
        super().__init__()
        self.bert_tokenizer = BertTokenizer.from_pretrained(bert_path)
        self.bert = BertModel.from_pretrained(bert_path)
        # Fine-tune the whole encoder (explicit; parameters are trainable by default).
        for param in self.bert.parameters():
            param.requires_grad = True
        # NOTE(review): ``fc`` is not used by ``forward``. Its input size (128)
        # also does not match BERT's hidden size
        # (``self.bert.config.hidden_size``, typically 768 for base-size
        # checkpoints). It is left unchanged so existing state_dicts still load.
        # Fix the size before wiring this head in.
        self.fc = nn.Linear(128, 2)

    def forward(self, sentence_lists, verbose=False):
        """Run BERT over a batch of sentences.

        :param sentence_lists: list of sentences with stop words already removed
            (e.g. the output of ``Pre.preprocess``). Each item is joined with
            spaces before tokenization; for a string that means character by
            character.
        :param verbose: if True, print the intermediate inputs (the former
            unconditional debug output).
        :return: BERT's last hidden state. Per the transformers docs this is a
            tensor of shape (batch, seq_len, hidden_size).
        """
        sentence_lists = [' '.join(x) for x in sentence_lists]
        encoded = self.bert_tokenizer(sentence_lists, padding=True, return_tensors="pt")
        # Keep the inputs on the same device as the encoder (CPU or GPU).
        encoded = encoded.to(self.bert.device)
        if verbose:
            print(sentence_lists)
            print(encoded)
            print(encoded['input_ids'])
        # Pass attention_mask (and token_type_ids) as well. Otherwise the padding
        # added to shorter sentences in the batch is attended to rather than
        # masked out.
        outputs = self.bert(**encoded)
        return outputs[0]

# In [38]:
class Pre:
    """Split a Chinese text into sentences and clean each one for the encoder.

    ``preprocess`` runs these steps:

    1. coarse sentence split on punctuation/whitespace;
    2. jieba segmentation with stop-word removal and symbol stripping;
    3. TextRank keyword extraction for sentences that are still long.
    """

    # Strips ASCII letters/digits and common ASCII/full-width punctuation.
    # Raw string: this is the same regex as before, without the invalid-escape
    # warnings the plain string literal triggered.
    # NOTE(review): in the second character class, ``+-—`` is parsed as a
    # *range* from '+' (U+002B) to '—' (U+2014). It therefore strips every
    # character in that range (all ASCII from '+' upward, Latin-1, Greek,
    # Cyrillic, ...). CJK ideographs (U+4E00+) are unaffected. Confirm this
    # is intended before changing it.
    _STRIP_PATTERN = re.compile(
        r'[A-Za-z]*[0-9]*[\'\"\%.\s\@\!\#\$\^\&\*\(\)\-\<\>\?\/\,\~\`\:\;]*[：；”“ ‘’+-——！，。？、~@#￥%……&*（）【】]*'
    )

    def __init__(self, text, stopwords_path='dict/stop1205.txt'):
        """
        :param text: the raw input text.
        :param stopwords_path: UTF-8 stop-word file, one word per line.
        """
        # Delimiters are compared one character at a time, so each entry must
        # be a single character. The former two-character '……' entry could
        # never match; '…' replaces it. Full-width '！' and ASCII '?' are added
        # alongside the existing ASCII '!' and full-width '？'.
        self.puncs_coarse = ['。', '!', '！', '；', '？', '?', '…', '\n', ' ']
        self.text = text
        self.stopwords = self.deal_wrap(stopwords_path)

    def segment(self, sentence):
        """Segment *sentence* with jieba, drop stop words and strip symbols.

        :return: the remaining tokens concatenated without separators
            (may be ``''``).
        """
        # Use self.stopwords (the old code read an undefined global). Build a
        # set so each membership test is O(1).
        stopwords = set(self.stopwords)
        kept = [w for w in jieba.cut(sentence.strip())
                if w not in stopwords and w != '\t']
        # Re-split on single spaces, as the old string-building version did,
        # so tokens that contain spaces are cleaned piece by piece.
        pieces = ' '.join(kept).split(' ')
        cleaned = (self._STRIP_PATTERN.sub('', p.strip()) for p in pieces)
        return ''.join(p for p in cleaned if p)

    def deal_wrap(self, filedict):
        """Read *filedict* (UTF-8) and return its lines, stripped of surrounding whitespace."""
        with open(filedict, 'r', encoding='utf-8') as fh:
            return [line.strip() for line in fh]

    def split_sentence_coarse(self):
        """Split ``self.text`` into sentences on the characters in ``self.puncs_coarse``.

        Each delimiter stays at the end of the sentence it terminates. Trailing
        text without a final delimiter is returned as the last sentence. No
        other cleanup happens here; whitespace and symbols are stripped later
        by ``segment``.

        :return: list of sentence strings, delimiters included.
        """
        text = self.text
        delimiters = set(self.puncs_coarse)
        sentences = []
        start = 0
        for i, ch in enumerate(text):
            if ch in delimiters:
                sentences.append(text[start:i + 1])
                start = i + 1
        # Previously the unterminated tail of the text was silently dropped.
        if start < len(text):
            sentences.append(text[start:])
        return sentences

    def get_keywords(self, data, top_k=8):
        """Return the top ``top_k`` TextRank keywords of *data*, concatenated.

        Used for sentences that are too long.
        """
        from jieba import analyse
        return ''.join(analyse.textrank(data, topK=top_k))

    def preprocess(self, max_len=20):
        """Split the text and clean every sentence.

        Stop words and symbols are removed from each sentence. A cleaned
        sentence longer than ``max_len`` characters (an arbitrary default of
        20) is replaced by its concatenated TextRank keywords. Sentences that
        end up empty are dropped.

        :return: list of cleaned sentence strings.
        """
        new_sent = []
        for sentence in self.split_sentence_coarse():
            cleaned = self.segment(sentence)
            if len(cleaned) > max_len:
                cleaned = self.get_keywords(cleaned)
            if cleaned:
                new_sent.append(cleaned)
        return new_sent