就拿`conll03`数据集测试

In [8]:
class Config:
    def __init__(self, config_data):
        config = config_data

        self.dataset = config["dataset"]

        self.dist_emb_size = config["dist_emb_size"]
        self.type_emb_size = config["type_emb_size"]
        self.lstm_hid_size = config["lstm_hid_size"]
        self.conv_hid_size = config["conv_hid_size"]
        self.bert_hid_size = config["bert_hid_size"]
        self.biaffine_size = config["biaffine_size"]
        self.ffnn_hid_size = config["ffnn_hid_size"]

        self.dilation = config["dilation"]

        self.emb_dropout = config["emb_dropout"]
        self.conv_dropout = config["conv_dropout"]
        self.out_dropout = config["out_dropout"]

        self.epochs = config["epochs"]
        self.batch_size = config["batch_size"]

        self.learning_rate = config["learning_rate"]
        self.weight_decay = config["weight_decay"]
        self.clip_grad_norm = config["clip_grad_norm"]
        self.bert_name = config["bert_name"]
        self.bert_learning_rate = config["bert_learning_rate"]
        self.warm_factor = config["warm_factor"]

        self.use_bert_last_4_layers = config["use_bert_last_4_layers"]

        self.seed = config["seed"]


    def __repr__(self):
        return "{}".format(self.__dict__.items())

读取数据

In [9]:
import json
from transformers import AutoTokenizer
import data_loader

with open("./config/conll03.json", "r", encoding="utf-8") as f:
    config_data = json.load(f)

# with open("./data/{}/train.json".format(config_data['dataset']), "r", encoding="utf-8") as f:
#     train_data = json.load(f)

# with open("./data/{}/dev.json".format(config_data['dataset']), "r", encoding="utf-8") as f:
#     train_data = json.load(f)

# with open("./data/{}/test.json".format(config_data['dataset']), "r", encoding="utf-8") as f:
#     test_data = json.load(f)

tokenizer = AutoTokenizer.from_pretrained(config_data['bert_name'], cache_dir="./cache")

In [10]:
config = Config(config_data)

In [11]:
import utils

logger = utils.get_logger(config.dataset)
logger.info(config)
config.logger = logger
datasets = data_loader.load_data_bert(config)[0]

2022-05-12 10:22:24 - INFO: dict_items([('dataset', 'conll03'), ('dist_emb_size', 20), ('type_emb_size', 20), ('lstm_hid_size', 768), ('conv_hid_size', 96), ('bert_hid_size', 1024), ('biaffine_size', 768), ('ffnn_hid_size', 128), ('dilation', [1, 2, 3]), ('emb_dropout', 0.5), ('conv_dropout', 0.5), ('out_dropout', 0.33), ('epochs', 10), ('batch_size', 12), ('learning_rate', 0.001), ('weight_decay', 0), ('clip_grad_norm', 1.0), ('bert_name', 'bert-large-cased'), ('bert_learning_rate', 1e-05), ('warm_factor', 0.1), ('use_bert_last_4_layers', True), ('seed', 123)])
2022-05-12 10:22:24 - INFO: dict_items([('dataset', 'conll03'), ('dist_emb_size', 20), ('type_emb_size', 20), ('lstm_hid_size', 768), ('conv_hid_size', 96), ('bert_hid_size', 1024), ('biaffine_size', 768), ('ffnn_hid_size', 128), ('dilation', [1, 2, 3]), ('emb_dropout', 0.5), ('conv_dropout', 0.5), ('out_dropout', 0.33), ('epochs', 10), ('batch_size', 12), ('learning_rate', 0.001), ('weight_decay', 0), ('clip_grad_norm', 1.0), 

In [12]:
len(datasets)

3

In [13]:
from torch.utils.data import DataLoader

train_loader, dev_loader, test_loader = (
    DataLoader(dataset=dataset,
                batch_size=config.batch_size,
                collate_fn=data_loader.collate_fn,
                shuffle=i == 0,
                num_workers=4,
                drop_last=i == 0)
    for i, dataset in enumerate(datasets)
)

关于`dataset`中的`tensor`: `bert_inputs`, `grid_labels`, `grid_mask2d`, `pieces2word`, `dist_inputs`, `sent_length`, `entity_text`

模型部分

In [14]:
from model import Model

model = Model(config)

Some weights of the model checkpoint at bert-large-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [15]:
import torch
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '7'

model_path = "./model.pt"
model.load_state_dict(torch.load(model_path))

<All keys matched successfully>

评估模型

In [16]:
import utils
# 评估模式
model.eval()
model.cuda()

pred_result = []
label_result = []

with torch.no_grad():
    for i, data_batch in enumerate(test_loader):
        entity_text = data_batch[-1]
        label_result.append(entity_text)
        data_batch = [data.cuda() for data in data_batch[:-1]]
        bert_inputs, grid_labels, grid_mask2d, pieces2word, dist_inputs, sent_length = data_batch

        outputs = model(bert_inputs, grid_mask2d, dist_inputs, pieces2word, sent_length)
        outputs = torch.argmax(outputs, -1)
        predictions = utils.get_predictions(outputs.cpu().numpy(), entity_text, sent_length.cpu().numpy())
        pred_result.append(predictions)
        # label_result.append(grid_labels)
        # pred_result.append(outputs)

In [17]:
labels = []
preds = []
for i in range(len(pred_result)):
    labels.extend(label_result[i])
    preds.extend(pred_result[i])

In [18]:
labels[: 5]

[{'2-#-5', '7-#-4'},
 {'0-1-#-4'},
 {'0-#-5', '2-3-4-#-5'},
 {'0-#-5', '15-#-5', '6-7-#-3'},
 {'1-#-5', '23-#-5'}]

In [19]:
def get_entities(sentence, entity_text):
    """
    根据entity_text和原文得到真正的实体
    """
    entities = []
    for entity in entity_text:
        ids = entity.split("-")
        ids = ids[: -2]
        entity_ids = [int(x) for x in ids]
        entity = [sentence[idx] for idx in entity_ids]
        entities.append(" ".join(entity))
    return ",".join(entities)

In [14]:
def isContinuous(entity):
    """
    该实体是否连续
    """
    continuous = True
    ids = entity.split("-")
    ids = ids[: -2]
    ids = [int(x) for x in ids]
    start, end = ids[0], ids[-1]
    if list(range(start, end + 1)) != ids:
        continuous = False
    return continuous

一个示例

In [20]:
get_entities(test_data[0]['sentence'], labels[0])

''

写入预测文件

In [None]:
with open("./predictions/share_2013_pred.txt", "w", encoding="utf-8") as f:
    for i in range(len(test_data)):
        f.write("raw sentence: \n")
        f.write(" ".join(test_data[i]['sentence']) + "\n")
        f.write("true entities: \n")
        f.write(get_entities(test_data[i]['sentence'], labels[i]) + "\n")
        f.write("predicted entities: \n")
        f.write(get_entities(test_data[i]['sentence'], preds[i]) + "\n")
        f.write("\n")

分析预测文件

先看整体表现和非连续实体表现

In [21]:
# with open("./predictions/share_2013_pred.txt", "r", encoding="utf-8") as f:
#     lines = f.readlines()

num_entities = 0    # 实体数量
num_correct_entities = 0     # 正确实体数量
num_preds = 0    # 预测实体数量

num_dis_entities = 0    # 非连续实体数量
num_correct_dis_entities = 0    # 正确非连续实体数量
num_preds_dis = 0    # 预测非连续实体数量

for i in range(len(test_data)):
    true_entity = labels[i]
    num_entities += len(true_entity)

    pred_entity = preds[i]
    num_preds += len(pred_entity)
    
    for entity in true_entity:
        if entity in pred_entity:
            num_correct_entities += 1
            if isContinuous(entity) == False:
                num_correct_dis_entities += 1
        if isContinuous(entity) == False:
            num_dis_entities += 1
    
    for entity in pred_entity:
        if isContinuous(entity) == False:
            num_preds_dis += 1


print("num_entities: ", num_entities)
print("num_correct_entities: ", num_correct_entities)
print("num_preds: ", num_preds)
print("num_dis_entities: ", num_dis_entities)
print("num_correct_dis_entities: ", num_correct_dis_entities)
print("num_preds_dis: ", num_preds_dis)

num_entities:  5333
num_correct_entities:  4163
num_preds:  5111
num_dis_entities:  436
num_correct_dis_entities:  219
num_preds_dis:  344


In [29]:
from statistics import get_relation_matrix
from statistics import count_relation

count_relation(get_relation_matrix(test_data[12], labels[12]))

In [None]:
import dis_utils

# 得到含有非连续实体的句子的索引
dis_indexes = []

# 得到不含有实体的句子的索引
no_entity_indexes = []

# 得到只含有连续实体的句子的索引
entity_indexes = []

for i, sample in enumerate(test_data):
    sentence = sample["sentence"]
    ner = sample["ner"]
    dis_ner = []
    if not ner:
        no_entity_indexes.append(i)
    else:
        for item in ner:
            indexes = item['index']

            if not dis_utils.isContinuous(indexes):
                dis_ner.append(item)
        if not dis_ner:
            entity_indexes.append(i)
        else:
            dis_indexes.append(i)

In [23]:
with open("./predictions/share_2013_pred.txt", "r", encoding="utf-8") as f:
    lines = f.readlines()

In [None]:
no_entity_error_lines = []
dis_error_lines = []
entity_error_lines = []
for i in range(len(lines)):
    if i % 7 == 0:
        true_entity = lines[i + 3].strip()
        pred_entity = lines[i + 5].strip()
        if sorted(true_entity.split(",")) != sorted(pred_entity.split(",")):
            index = int(i / 7)
            if index in no_entity_indexes:
                no_entity_error_lines.extend(lines[i: i + 7])
            elif index in dis_indexes:
                dis_error_lines.extend(lines[i: i + 7])
            elif index in entity_indexes:
                entity_error_lines.extend(lines[i: i + 7])

In [None]:
with open("./predictions/share_2013_error/no_entity_error.txt", "w", encoding="utf-8") as f:
    for line in no_entity_error_lines:
        f.write(line)

In [None]:
with open("./predictions/share_2013_error/dis_entity_error.txt", "w", encoding="utf-8") as f:
    for line in dis_error_lines:
        f.write(line)

In [None]:
with open("./predictions/share_2013_error/entity_error.txt", "w", encoding="utf-8") as f:
    for line in entity_error_lines:
        f.write(line)

In [24]:
num_predict_no_entity = 0   # 预测为空实体的句子数
num_no_entity = 0   # 实体为空的句子数
num_sentences = 0   # 总句子数

for i in range(len(lines)):
    if i % 7 == 0:
        num_sentences += 1
        true_entity = lines[i + 3].strip()
        pred_entity = lines[i + 5].strip()
        if true_entity == "":
            num_no_entity += 1
        if pred_entity == "":
            num_predict_no_entity += 1
print(f"实体为空的句子数：{num_no_entity}")
print("占整个测试集的比例：{}".format(num_no_entity / num_sentences))
print(f"预测为空实体的句子数：{num_predict_no_entity}")
print("占整个测试集的比例：{}".format(num_predict_no_entity / num_sentences))

实体为空的句子数：5910
占整个测试集的比例：0.656010656010656
预测为空实体的句子数：6002
占整个测试集的比例：0.6662226662226662


In [25]:
num_entities = 0   # 实体数目
num_pred_entities = 0   # 预测实体数目

for i in range(len(lines)):
    if i % 7 == 0:
        true_entity = lines[i + 3].strip()
        num_entities += len(true_entity.split(","))

        pred_entity = lines[i + 5].strip()
        num_pred_entities += len(pred_entity.split(","))
print(f"实体数目：{num_entities}")
print("预测实体数目：{}".format(num_pred_entities))

实体数目：11244
预测实体数目：11114


统计关系数目

In [38]:
num = 0
num0 = 0
num1 = 0
num2 = 0
pre_num = 0
pre_num0 = 0
pre_num1 = 0
pre_num2 = 0

for i in range(len(test_data)):
    sample = test_data[i]
    label = labels[i]
    pred = preds[i]

    truth_matrix = get_relation_matrix(sample, label)
    pred_matrix = get_relation_matrix(sample, pred)

    a, b, c, d = count_relation(truth_matrix)
    num0 += a
    num1 += b
    num2 += c
    num += d

    a1, b1, c1, d1 = count_relation(pred_matrix)
    pre_num0 += a1
    pre_num1 += b1
    pre_num2 += c1
    pre_num += d1


print(f'一共有{num}个关系')
print(f'一共有{num0}个关系为0')
print(f'一共有{num1}个关系为1')
print(f'一共有{num2}个关系为2')
print(f'一共预测{pre_num}个关系')
print(f'一共预测{pre_num0}个关系为0')
print(f'一共预测{pre_num1}个关系为1')
print(f'一共预测{pre_num2}个关系为2')

一共有3719483个关系
一共有3710320个关系为0
一共有3834个关系为1
一共有5329个关系为2
一共预测3719483个关系
一共预测3710420个关系为0
一共预测3964个关系为1
一共预测5099个关系为2
