导入 Python 库&模块并配置运行信息

In [None]:
import math
import numpy as np
import pandas as pd
import os
import random
import codecs
from pathlib import Path

import mindspore
import mindspore.dataset as ds
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.train.model import Model
from mindspore.nn.metrics import Accuracy
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import (
    ModelCheckpoint,
    CheckpointConfig,
    LossMonitor,
    TimeMonitor,
)
from mindspore.ops import operations as ops

from easydict import EasyDict as edict

# Run configuration shared by the cells below: data location, training
# hyper-parameters, checkpoint settings and text-vectorisation sizes.
cfg = edict(
    dict(
        name="moive review",
        pre_trained=False,
        num_classes=2,
        batch_size=64,
        epoch_size=4,
        weight_decay=3e-5,
        data_path="./data/",
        device_target="CPU",
        device_id=0,
        keep_checkpoint_max=1,
        checkpoint_path="./ckpt/train_textcnn-4_149.ckpt",
        # every sample is padded / truncated to this many token ids
        word_len=51,
        # embedding dimension of each token
        vec_length=40,
    )
)

# Run MindSpore in GRAPH_MODE on the device selected in cfg.
context.set_context(
    mode=context.GRAPH_MODE,
    device_target=cfg.device_target,
    device_id=cfg.device_id,
)

数据读取和预处理

In [None]:
# Preview the first five reviews of each polarity file.
for title, file_name in (
    ("Negative reviews:", "rt-polarity.neg"),
    ("Positive reviews:", "rt-polarity.pos"),
):
    # Read from the configured data directory rather than a second hard-coded path.
    with open(os.path.join(cfg.data_path, file_name), "r", encoding="utf-8") as f:
        print(title)
        for i in range(5):
            line = f.readline()
            if not line:  # the file has fewer than five lines
                break
            # readline() keeps the trailing newline; strip it so print() does
            # not emit an extra blank line after every review.
            print("[{0}]:{1}".format(i, line.rstrip("\n")))

定义数据处理函数代码

In [None]:
# Data-processing helpers
class Generator:
    """Random-access sample source handed to ``ds.GeneratorDataset``.

    Each element of ``input_list`` is expected to be a ``[token_ids, label]``
    pair; indexing returns both converted to ``np.int32`` arrays.
    """

    def __init__(self, input_list):
        self.input_list = input_list

    def __len__(self):
        return len(self.input_list)

    def __getitem__(self, item):
        pair = self.input_list[item]
        tokens = np.array(pair[0], dtype=np.int32)
        label = np.array(pair[1], dtype=np.int32)
        return tokens, label


class MovieReview:
    """
    Movie-review polarity dataset.

    Reads the two files found directly inside ``root_dir``. Lines from the
    file whose name contains "pos" are labelled 1; lines from the other file
    are labelled 0. Each line is cleaned, split on spaces, and mapped to
    integer ids (right-padded with 0 / truncated to ``maxlen``). The samples
    are then split into ``self.train`` and ``self.test``.

    Args:
        root_dir: directory that must contain exactly two files.
        maxlen: number of token ids per sample.
        split: value in (0, 1); the held-out test chunk of each class has
            ``ceil((1 - split) * n)`` samples.

    Raises:
        ValueError: if ``root_dir`` is not a directory, does not contain
            exactly two files, or ``split`` is unusable.
    """

    # Single characters deleted from every line before tokenisation. "<b>"
    # and then "%" are removed afterwards in read_data, which keeps the exact
    # order of the original chain of str.replace calls.
    _DELETE_CHARS = str.maketrans("", "", "\n\"'.,[]():-\\0123456789`=$/*;")

    def __init__(self, root_dir, maxlen, split):
        self.path = root_dir
        self.feelMap = {"neg": 0, "pos": 1}
        self.files = []

        self.doConvert = False

        mypath = Path(self.path)
        if not mypath.exists() or not mypath.is_dir():
            raise ValueError("please check the root_dir!")

        # Collect files from the top level of root_dir only (note the break).
        for root, _, filenames in os.walk(self.path):
            for each in filenames:
                self.files.append(os.path.join(root, each))
            break

        # Exactly one negative and one positive file are expected.
        if len(self.files) != 2:
            raise ValueError(
                "There are {} files in the root_dir".format(len(self.files))
            )

        # Corpus statistics updated by read_data (sentence lengths in tokens).
        self.word_num = 0
        self.minlen = float("inf")
        self.maxlen = float("-inf")

        self.Pos = []
        self.Neg = []
        for filename in self.files:
            self.read_data(filename)

        self.text2vec(maxlen=maxlen)
        self.split_dataset(split=split)

    def read_data(self, filePath):
        """Tokenise every non-empty line of ``filePath`` and append it to Pos or Neg."""
        # Decide the label from the file name only. Testing the whole path
        # would mislabel negative reviews whenever a parent directory name
        # contains "pos" (e.g. ".../repos/data/rt-polarity.neg").
        is_pos = "pos" in os.path.basename(filePath)
        # Explicit encoding so results do not depend on the platform locale
        # (the preview cell reads the same files as utf-8).
        with open(filePath, "r", encoding="utf-8") as f:
            for line in f:
                cleaned = line.translate(self._DELETE_CHARS)
                cleaned = cleaned.replace("<b>", "").replace("%", "")
                sentence = [word for word in cleaned.split(" ") if word]
                if not sentence:
                    continue
                self.word_num += len(sentence)
                self.maxlen = max(self.maxlen, len(sentence))
                self.minlen = min(self.minlen, len(sentence))
                if is_pos:
                    self.Pos.append([sentence, self.feelMap["pos"]])
                else:
                    self.Neg.append([sentence, self.feelMap["neg"]])

    def text2vec(self, maxlen):
        """Replace each token list in Pos/Neg with a ``maxlen``-long list of ids.

        Ids are assigned in order of first appearance (Pos first, then Neg).
        Words beyond ``maxlen`` are neither encoded nor added to the vocabulary.
        """
        # NOTE(review): id 0 is both the padding value and the id of the first
        # vocabulary word. Reserving 0 for padding would change get_dict_len()
        # and therefore the embedding size, so it is left unchanged here.
        self.Vocab = dict()

        for SentenceLabel in self.Pos + self.Neg:
            vector = [0] * maxlen
            for index, word in enumerate(SentenceLabel[0][:maxlen]):
                # A new word gets the next free id (the current vocabulary size).
                vector[index] = self.Vocab.setdefault(word, len(self.Vocab))
            SentenceLabel[0] = vector
        self.doConvert = True

    def split_dataset(self, split):
        """Hold out chunk 2 of each class as ``self.test``; the rest becomes ``self.train``.

        Each class list is viewed as consecutive chunks of
        ``ceil((1 - split) * len)`` samples. The third chunk (index 2) is the
        test data. Every other sample is training data, which is shuffled in
        place with ``random.shuffle``.

        Raises:
            ValueError: if ``split`` is not strictly between 0 and 1, or the
                held-out chunk would be empty for either class.
        """
        if not 0 < split < 1:
            raise ValueError("split must be in (0, 1), got {}".format(split))

        test_fold = 2
        pos_train, pos_test = self._hold_out(self.Pos, split, test_fold)
        neg_train, neg_test = self._hold_out(self.Neg, split, test_fold)
        if not pos_test or not neg_test:
            raise ValueError(
                "split={} leaves no samples in the test fold".format(split)
            )
        self.test = pos_test + neg_test
        self.train = pos_train + neg_train

        random.shuffle(self.train)

    @staticmethod
    def _hold_out(samples, split, fold):
        """Return ``(train, test)``, where test is chunk ``fold`` of ``samples``."""
        size = math.ceil((1 - split) * len(samples))
        start, stop = fold * size, (fold + 1) * size
        # Take the complement directly instead of re-assembling a computed
        # number of chunks. int(1 / (1 - split)) can truncate (e.g. 19 for
        # split=0.95) and would silently drop the tail of the data.
        return samples[:start] + samples[stop:], samples[start:stop]

    def get_dict_len(self):
        """Return the vocabulary size, or -1 (with a message) before text2vec ran."""
        if self.doConvert:
            return len(self.Vocab)
        else:
            print("Haven't finished Text2Vec!")
            return -1

    def create_train_dataset(self, epoch_size, batch_size):
        """Return the training samples as a batched dataset repeated ``epoch_size`` times.

        An incomplete final batch is dropped.
        """
        # NOTE(review): the training cell also passes epoch_size to
        # Model.train; confirm that repeating here as well is intended.
        dataset = ds.GeneratorDataset(
            source=Generator(input_list=self.train),
            column_names=["data", "label"],
            shuffle=False,  # self.train is already shuffled in split_dataset
        )
        dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
        dataset = dataset.repeat(epoch_size)
        return dataset

    def create_test_dataset(self, batch_size):
        """Return the test samples as a batched dataset (incomplete final batch dropped)."""
        dataset = ds.GeneratorDataset(
            source=Generator(input_list=self.test),
            column_names=["data", "label"],
            shuffle=False,
        )
        dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
        return dataset

In [None]:
# Build the vocabulary and the train/test split (split=0.9 keeps roughly
# 90 % of each class for training), then the batched training dataset.
instance = MovieReview(root_dir=cfg.data_path, maxlen=cfg.word_len, split=0.9)
dataset = instance.create_train_dataset(
    epoch_size=cfg.epoch_size,
    batch_size=cfg.batch_size,
)
# Number of batches reported by the batched, repeated training dataset.
batch_num = dataset.get_dataset_size()

显示数据处理结果代码

In [None]:
vocab_size = instance.get_dict_len()
print("vocab_size: ", vocab_size)
# Inspect only the first batch of the training dataset.
item = dataset.create_dict_iterator()
data = next(item, None)
if data is not None:
    print(data)
    print(data["data"][1])

配置训练参数

In [None]:
def build_learning_rate(epoch_size, steps_per_epoch, base_lr=1e-3):
    """Return a per-step learning-rate list that spans exactly ``epoch_size`` epochs.

    The schedule has three consecutive phases, each made of whole epochs of
    ``steps_per_epoch`` steps:

    * warm-up, ``floor(epoch_size / 5)`` epochs: epoch ``i`` uses
      ``base_lr / warm_epochs * (i + 1)``;
    * constant ``base_lr`` for the epochs not used by the other two phases;
    * shrink, ``floor(epoch_size * 3 / 5)`` epochs: epoch ``i`` uses
      ``base_lr / (16 * (i + 1))``.
    """
    warm_epochs = math.floor(epoch_size / 5)
    shrink_epochs = math.floor(epoch_size * 3 / 5)
    # The constant phase fills the remainder, so the three phases add up to
    # epoch_size epochs and the list length equals the number of training steps.
    normal_epochs = epoch_size - warm_epochs - shrink_epochs

    schedule = []
    # Epoch loop outside, step loop inside: every step of one epoch shares the
    # same rate instead of the rates alternating from step to step.
    for i in range(warm_epochs):
        schedule.extend([base_lr / warm_epochs * (i + 1)] * steps_per_epoch)
    schedule.extend([base_lr] * (normal_epochs * steps_per_epoch))
    for i in range(shrink_epochs):
        schedule.extend([base_lr / (16 * (i + 1))] * steps_per_epoch)
    return schedule


learning_rate = build_learning_rate(cfg.epoch_size, batch_num)

TextCNN 模型定义

In [None]:
def _weight_variable(shape, factor=0.01):
    """Return a float32 Tensor of ``shape`` with standard-normal samples scaled by ``factor``."""
    scaled = factor * np.random.randn(*shape).astype(np.float32)
    return Tensor(scaled)


def make_conv_layer(kernel_size):
    """Build a 1-in / 96-out ``nn.Conv2d`` initialised with small random weights.

    Args:
        kernel_size: ``(height, width)`` tuple, or a single int for a square
            kernel (the same forms ``nn.Conv2d`` accepts).

    Returns:
        ``nn.Conv2d`` with ``padding=1`` in "pad" mode and a bias term.
    """
    # Normalise an int to (k, k): the weight shape below unpacks kernel_size,
    # and unpacking a plain int would raise TypeError.
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    out_channels = 96
    weight = _weight_variable((out_channels, 1, *kernel_size))
    return nn.Conv2d(
        in_channels=1,
        out_channels=out_channels,
        kernel_size=kernel_size,
        padding=1,
        pad_mode="pad",
        weight_init=weight,
        has_bias=True,
    )


class TextCNN(nn.Cell):
    """TextCNN sentence classifier.

    Embeds token ids and applies three parallel convolution branches. Their
    kernels are 3, 4 and 5 tokens tall and span the full embedding width.
    Each branch is reduced with a global max, the three 96-channel features
    are concatenated, and the result is classified with Dropout + Dense.

    Args:
        vocab_len: number of rows in the embedding table.
        word_len: number of token ids per input sample.
        num_classes: number of output logits.
        vec_length: embedding dimension.
    """

    def __init__(self, vocab_len, word_len, num_classes, vec_length):
        super(TextCNN, self).__init__()
        self.vec_length = vec_length
        self.word_len = word_len
        self.num_classes = num_classes

        self.unsqueeze = ops.ExpandDims()
        self.embedding = nn.Embedding(
            vocab_len, self.vec_length, embedding_table="normal"
        )

        self.slice = ops.Slice()  # not used by construct
        # Each branch is conv(kernel_height x vec_length) -> ReLU -> MaxPool2d,
        # built by make_layer. The kernel must span the whole embedding width;
        # calling make_conv_layer with a bare int would fail to unpack the
        # int and would skip the ReLU/pooling stages.
        self.layer1 = self.make_layer(kernel_height=3)
        self.layer2 = self.make_layer(kernel_height=4)
        self.layer3 = self.make_layer(kernel_height=5)

        self.concat = ops.Concat(1)

        self.fc = nn.Dense(96 * 3, self.num_classes)
        self.drop = nn.Dropout(keep_prob=0.5)
        self.print = ops.Print()  # not used by construct
        # Despite its name this is a max reduction; the attribute name is kept
        # for compatibility.
        self.reducemean = ops.ReduceMax(keep_dims=False)

    def make_layer(self, kernel_height):
        """Return conv(kernel_height x vec_length) -> ReLU -> MaxPool2d as one cell.

        The pooling window is ``word_len - kernel_height + 1`` rows by 1 column.
        """
        return nn.SequentialCell(
            [
                make_conv_layer((kernel_height, self.vec_length)),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=(self.word_len - kernel_height + 1, 1)),
            ]
        )

    def construct(self, x):
        # Assumes x is (batch, word_len) token ids -- TODO confirm with callers.
        x = self.unsqueeze(x, 1)  # add a channel axis: (batch, 1, word_len)
        x = self.embedding(x)  # -> (batch, 1, word_len, vec_length)
        x1 = self.layer1(x)
        x2 = self.layer2(x)
        x3 = self.layer3(x)

        # Global max over the two spatial axes -> (batch, 96) per branch.
        x1 = self.reducemean(x1, (2, 3))
        x2 = self.reducemean(x2, (2, 3))
        x3 = self.reducemean(x3, (2, 3))

        x = self.concat((x1, x2, x3))  # (batch, 96 * 3)
        x = self.drop(x)
        x = self.fc(x)
        return x


# Instantiate the classifier sized to the vocabulary built from the corpus.
net = TextCNN(
    vec_length=cfg.vec_length,
    num_classes=cfg.num_classes,
    word_len=cfg.word_len,
    vocab_len=instance.get_dict_len(),
)
print(net)

定义训练的相关参数

In [None]:
# Optimizer, loss function, checkpointing and monitoring callbacks.
# Only parameters with requires_grad=True are handed to Adam.
opt = nn.Adam(
    [param for param in net.get_parameters() if param.requires_grad],
    learning_rate=learning_rate,
    weight_decay=cfg.weight_decay,
)
# sparse=True: labels are integer class ids (0/1), not one-hot vectors.
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={"acc": Accuracy()})

# Save a checkpoint every half of epoch_size * batch_num steps and keep at
# most keep_checkpoint_max files.
config_ck = CheckpointConfig(
    save_checkpoint_steps=int(cfg.epoch_size * batch_num / 2),
    keep_checkpoint_max=cfg.keep_checkpoint_max,
)
ckpt_save_dir = "./ckpt"
ckpoint_cb = ModelCheckpoint(
    prefix="train_textcnn", directory=ckpt_save_dir, config=config_ck
)
time_cb = TimeMonitor(data_size=batch_num)
loss_cb = LossMonitor()

启动训练

In [None]:
# Train for cfg.epoch_size epochs with timing, checkpoint and loss callbacks.
callbacks = [time_cb, ckpoint_cb, loss_cb]
model.train(cfg.epoch_size, dataset, callbacks=callbacks)
print("train success")

测试评估

In [None]:
def preprocess(sentence):
    sentence = sentence.lower().strip()
    sentence = (
        sentence.replace("\n", "")
        .replace('"', "")
        .replace("'", "")
        .replace(".", "")
        .replace(",", "")
        .replace("[", "")
        .replace("]", "")
        .replace("(", "")
        .replace(")", "")
        .replace(":", "")
        .replace("--", "")
        .replace("-", "")
        .replace("\\", "")
        .replace("0", "")
        .replace("1", "")
        .replace("2", "")
        .replace("3", "")
        .replace("4", "")
        .replace("5", "")
        .replace("6", "")
        .replace("7", "")
        .replace("8", "")
        .replace("9", "")
        .replace("`", "")
        .replace("=", "")
        .replace("$", "")
        .replace("/", "")
        .replace("*", "")
        .replace(";", "")
        .replace("<b>", "")
        .replace("%", "")
        .replace("  ", " ")
    )
    sentence = sentence.split(" ")
    maxlen = cfg.word_len
    vector = [0] * maxlen
    for index, word in enumerate(sentence):
        if index >= maxlen:
            break
        if word not in instance.Vocab.keys():
            print("word {} not in vocab".format(word))
        else:
            vector[index] = instance.Vocab[word]
    sentence = vector
    
    return sentence

def inference(revire_en):
    """Classify one raw review with ``net`` and print the predicted polarity.

    Args:
        revire_en: raw review text (parameter name kept for compatibility).
    """
    token_ids = preprocess(revire_en)
    # The network expects a batch: wrap the single sample so the input is
    # (1, word_len) int32 instead of a bare (word_len,) vector.
    input_en = Tensor(np.array([token_ids], dtype=np.int32))
    # Put the network in evaluation mode so Dropout is not applied during
    # prediction (a later model.train call switches training mode back on).
    net.set_train(False)
    output = net(input_en)
    if np.argmax(output.asnumpy()[0]) == 1:
        print("Positive comments")
    else:
        print("Negative comments")

In [None]:
# Run the classifier on a sample review.
revire_en = "the movie is so boring"
inference(revire_en=revire_en)