In [6]:
import os
import argparse
import mindspore.communication.management as D
from mindspore.communication.management import get_rank
import mindspore.common.dtype as mstype
from mindspore import context
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
#from mindspore.train.train_thor import ConvertModelUtils
from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay#, THOR
from mindspore import log as logger
from mindspore.common import set_seed
from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
                BertTrainAccumulationAllReduceEachWithLossScaleCell, \
                BertTrainAccumulationAllReducePostWithLossScaleCell, \
                BertTrainOneStepWithLossScaleCellForAdam, \
                AdamWeightDecayForBert
from src.dataset import create_bert_dataset
from src.config import cfg, bert_net_cfg
from src.utils import LossCallBack, BertLearningRate
from mindspore.nn.metrics import Metric
import mindspore.nn as nn


In [2]:
net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)
param_dict = load_checkpoint('2.ckpt')
load_param_into_net(net_with_loss, param_dict)



['bert.bert.bert_encoder.layers.0.attention.attention.query_layer.weight',
 'bert.bert.bert_encoder.layers.0.attention.attention.query_layer.bias',
 'bert.bert.bert_encoder.layers.0.attention.attention.key_layer.weight',
 'bert.bert.bert_encoder.layers.0.attention.attention.key_layer.bias',
 'bert.bert.bert_encoder.layers.0.attention.attention.value_layer.weight',
 'bert.bert.bert_encoder.layers.0.attention.attention.value_layer.bias',
 'bert.bert.bert_encoder.layers.1.attention.attention.query_layer.weight',
 'bert.bert.bert_encoder.layers.1.attention.attention.query_layer.bias',
 'bert.bert.bert_encoder.layers.1.attention.attention.key_layer.weight',
 'bert.bert.bert_encoder.layers.1.attention.attention.key_layer.bias',
 'bert.bert.bert_encoder.layers.1.attention.attention.value_layer.weight',
 'bert.bert.bert_encoder.layers.1.attention.attention.value_layer.bias',
 'bert.bert.bert_encoder.layers.2.attention.attention.query_layer.weight',
 'bert.bert.bert_encoder.layers.2.attention.a

In [10]:
class myMetric(Metric):
    '''
    Self-defined Metric as a callback.
    '''
    def __init__(self):
        super(myMetric, self).__init__()
        self.clear()

    def clear(self):
        self.total_num = 0
        self.acc_num = 0

    def update(self, *inputs):
        total_num = self._convert_data(inputs[0])
        acc_num = self._convert_data(inputs[1])
        self.total_num = total_num
        self.acc_num = acc_num

    def eval(self):
        return self.acc_num/self.total_num


class GetLogProbs(nn.Cell):
    '''
    Get MaskedLM prediction scores
    '''
    def __init__(self, config):
        super(GetLogProbs, self).__init__()
        self.bert = BertModel(config, False)
        self.cls1 = GetMaskedLMOutput(config)

    def construct(self, input_ids, input_mask, token_type_id, masked_pos):
        sequence_output, _, embedding_table = self.bert(input_ids, token_type_id, input_mask)
        prediction_scores = self.cls1(sequence_output, embedding_table, masked_pos)
        return prediction_scores


class BertPretrainEva(nn.Cell):
    '''
    Evaluate MaskedLM prediction scores
    '''
    def __init__(self, config):
        super(BertPretrainEva, self).__init__()
        self.bert = GetLogProbs(config)
        self.argmax = P.Argmax(axis=-1, output_type=mstype.int32)
        self.equal = P.Equal()
        self.mean = P.ReduceMean()
        self.sum = P.ReduceSum()
        self.total = Parameter(Tensor([0], mstype.float32))
        self.acc = Parameter(Tensor([0], mstype.float32))
        self.reshape = P.Reshape()
        self.shape = P.Shape()
        self.cast = P.Cast()


    def construct(self, input_ids, input_mask, token_type_id, masked_pos, masked_ids, masked_weights, nsp_label):
        """Calculate prediction scores"""
        bs, _ = self.shape(input_ids)
        probs = self.bert(input_ids, input_mask, token_type_id, masked_pos)
        index = self.argmax(probs)
        index = self.reshape(index, (bs, -1))
        eval_acc = self.equal(index, masked_ids)
        eval_acc1 = self.cast(eval_acc, mstype.float32)
        real_acc = eval_acc1 * masked_weights
        acc = self.sum(real_acc)
        total = self.sum(masked_weights)
        self.total += total
        self.acc += acc
        return acc, self.total, self.acc


def get_enwiki_512_dataset(batch_size=1, repeat_count=1, distribute_file=''):
    '''
    Get enwiki dataset when seq_length is 512.
    '''
    ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask", "segment_ids",
                                                                            "masked_lm_positions", "masked_lm_ids",
                                                                            "masked_lm_weights",
                                                                            "next_sentence_labels"])
    type_cast_op = C.TypeCast(mstype.int32)
    ds = ds.map(operations=type_cast_op, input_columns="segment_ids")
    ds = ds.map(operations=type_cast_op, input_columns="input_mask")
    ds = ds.map(operations=type_cast_op, input_columns="input_ids")
    ds = ds.map(operations=type_cast_op, input_columns="masked_lm_ids")
    ds = ds.map(operations=type_cast_op, input_columns="masked_lm_positions")
    ds = ds.map(operations=type_cast_op, input_columns="next_sentence_labels")
    ds = ds.repeat(repeat_count)

    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds


def bert_predict():
    '''
    Predict function
    '''
    devid = int(0)
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=devid)
    net_for_pretraining = BertPretrainEva(bert_net_cfg)
    net_for_pretraining.set_train(False)
    param_dict = load_checkpoint("2.ckpt")
    load_param_into_net(net_for_pretraining, param_dict)
    model = Model(net_for_pretraining)
    return model, net_for_pretraining


def MLM_eval():
    '''
    Evaluate function
    '''
    _,  net_for_pretraining = bert_predict()
    net = Model(net_for_pretraining, eval_network=net_for_pretraining, eval_indexes=[0, 1, 2],
                metrics={'name': myMetric()})
    res = net.eval(dataset, dataset_sink_mode=False)
    print("==============================================================")
    for _, v in res.items():
        print("Accuracy is: ")
        print(v)
    print("==============================================================")



In [38]:
import mindspore.dataset.text as text

input_tx = "[CLS] [MASK] [MASK] [MASK] 是中国神魔小说的经典之作，与《三国演义》《水浒传》《红楼梦》并称为中国古典四大名著。[SEP]"

vof=open('vocab.txt',encoding='utf-8')
vocab = text.Vocab.from_list([i for i in vof.readline().strip()])


In [39]:
tokenizer_op = text.BertTokenizer(vocab=vocab)


In [44]:
import mindspore.dataset as ds
dataset = ds.TextFileDataset('input.txt', shuffle=False)

In [45]:
for data in dataset.create_dict_iterator(output_numpy=True):
    print(text.to_str(data['text']))

[CLS] [MASK] [MASK] [MASK] 是中国神魔小说的经典之作，与《三国演义》《水浒传》《红楼梦》并称为中国古典四大名著。[SEP]


In [46]:
dataset = dataset.map(operations=tokenizer_op)

print("------------------------after tokenization-----------------------------")

for i in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
    print(text.to_str(i['text']))

------------------------after tokenization-----------------------------
['[UNK]' '[UNK]' '[UNK]' '[UNK]' '[UNK]' '[UNK]' '[UNK]' '[UNK]' '[UNK]'
 '[UNK]' '[UNK]' '[UNK]' '[UNK]' '[UNK]' '[UNK]' '[UNK]' '[UNK]' '[UNK]'
 '[UNK]' '[UNK]' '[UNK]' '[UNK]' '[UNK]' '[UNK]' '[UNK]' '[UNK]' '[UNK]'
 '[UNK]' '[UNK]' '[UNK]' '[UNK]' '[UNK]' '[UNK]' '[UNK]' '[UNK]' '[UNK]'
 '[UNK]' '[UNK]' '[UNK]' '[UNK]' '[UNK]' '[UNK]' '[UNK]' '[UNK]' '[UNK]'
 '[UNK]' '[UNK]']
