## 训练部分

先获取数据

In [1]:
import json

# Read the experiment configuration; its 'dataset' entry names the data sub-directory.
with open("./config/conll03.json", mode="r", encoding="utf-8") as config_file:
    config_data = json.load(config_file)

# Load the train and test splits stored under ./data/<dataset>/.
train_path = "./data/{}/train.json".format(config_data['dataset'])
with open(train_path, mode="r", encoding="utf-8") as train_file:
    train_data = json.load(train_file)

test_path = "./data/{}/test.json".format(config_data['dataset'])
with open(test_path, mode="r", encoding="utf-8") as test_file:
    test_data = json.load(test_file)

In [2]:
from transformers import AutoTokenizer

# Load the tokenizer named by the config's 'bert_name' entry, using ./cache as the cache directory.
tokenizer = AutoTokenizer.from_pretrained(config_data['bert_name'], cache_dir="./cache")

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from data_loader import Vocabulary
from data_loader import fill_vocab

# Populate the label vocabulary from the training data; the value returned by
# fill_vocab is printed below as the entity count.
vocab = Vocabulary()
entity_num = fill_vocab(vocab, train_data)
print(entity_num)   # number of entities

29441


标签和索引的对应关系

In [4]:
vocab.id2label

{0: '<pad>', 1: '<suc>', 2: 'org', 3: 'misc', 4: 'per', 5: 'loc'}

如何从原始输入得到模型输入数据

In [5]:
from data_loader import RelationDataset
from data_loader import process_bert

# process_bert's return value is unpacked into RelationDataset's positional arguments.
# NOTE(review): presumably one parallel list per model input field — confirm in data_loader.
train_dataset = RelationDataset(*process_bert(train_data, tokenizer, vocab))

In [6]:
from torch.utils.data import DataLoader
import data_loader

# Batch with the project's collate_fn, exactly as the train/dev/test loaders built
# later in this notebook do. Samples differ in size from batch to batch and carry a
# set of entity strings as their last element, neither of which the DataLoader's
# default collation can stack into a batch.
train_loader = DataLoader(
    train_dataset,
    batch_size=config_data['batch_size'],
    collate_fn=data_loader.collate_fn,
    shuffle=True,
    num_workers=4,
    drop_last=True,
)

模型部分

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from model import LayerNorm
from model import ConvolutionLayer
from model import Biaffine
from model import MLP
from model import CoPredictor
from transformers import AutoModel

class Model(nn.Module):
    """Token-pair grid model.

    Pipeline, as wired in ``forward``: pretrained transformer encoder -> max pooling of
    subword pieces into word vectors -> BiLSTM -> ``LayerNorm`` (conditional) building a
    word-pair grid -> ``ConvolutionLayer`` over the grid -> ``CoPredictor``.
    ``LayerNorm``, ``ConvolutionLayer`` and ``CoPredictor`` come from the project's
    ``model`` module; their internals are not visible in this notebook.

    ``config`` is read through attribute access (``config.bert_name`` ...), so it must be
    an object exposing the settings as attributes, not a plain dict.
    """

    def __init__(self, config):
        super(Model, self).__init__()
        self.use_bert_last_4_layers = config.use_bert_last_4_layers

        self.lstm_hid_size = config.lstm_hid_size
        self.conv_hid_size = config.conv_hid_size

        lstm_input_size = 0

        # output_hidden_states=True so forward() can average the last four layers.
        self.bert = AutoModel.from_pretrained(config.bert_name, cache_dir="./cache/", output_hidden_states=True)
        lstm_input_size += config.bert_hid_size

        # Lookup tables indexed by dist_inputs (20 rows) and by the region id computed
        # in forward() (3 rows); indices outside those ranges are invalid.
        self.dis_embs = nn.Embedding(20, config.dist_emb_size)
        self.reg_embs = nn.Embedding(3, config.type_emb_size)

        # Bidirectional with lstm_hid_size // 2 units per direction, so the concatenated
        # output width is lstm_hid_size (for an even lstm_hid_size).
        self.encoder = nn.LSTM(lstm_input_size, config.lstm_hid_size // 2, num_layers=1, batch_first=True,
                               bidirectional=True)

        # Width of the concatenation [dis_emb, reg_emb, cln] assembled in forward().
        conv_input_size = config.lstm_hid_size + config.dist_emb_size + config.type_emb_size

        self.convLayer = ConvolutionLayer(conv_input_size, config.conv_hid_size, config.dilation, config.conv_dropout)
        self.dropout = nn.Dropout(config.emb_dropout)
        # NOTE(review): the 4th argument suggests ConvolutionLayer emits conv_hid_size
        # channels per dilation rate — confirm against model.py.
        self.predictor = CoPredictor(config.label_num, config.lstm_hid_size, config.biaffine_size,
                                     config.conv_hid_size * len(config.dilation), config.ffnn_hid_size,
                                     config.out_dropout)

        self.cln = LayerNorm(config.lstm_hid_size, config.lstm_hid_size, conditional=True)

    def forward(self, bert_inputs, grid_mask2d, dist_inputs, pieces2word, sent_length):
        '''
        Score every token pair of every sentence in the batch.

        :param bert_inputs: [B, L'], L': num of subwords + 2
        :param grid_mask2d: [B, L, L], L: num of tokens
        :param dist_inputs: [B, L, L], distance between tokens
        :param pieces2word: [B, L, L'], mapping between tokens and subwords
        :param sent_length: [B]
        :return: the output of ``self.predictor``; a note elsewhere in this notebook
            records a shape of [2, 12, 12, 6] for one batch, i.e. presumably
            [B, L, L, label_num]
        '''
        # Input id 0 is treated as padding when building the attention mask.
        bert_embs = self.bert(input_ids=bert_inputs, attention_mask=bert_inputs.ne(0).float())
        if self.use_bert_last_4_layers:
            # Average the last four entries of output element 2 (presumably the
            # per-layer hidden states requested via output_hidden_states=True).
            bert_embs = torch.stack(bert_embs[2][-4:], dim=-1).mean(-1)
        else:
            bert_embs = bert_embs[0]

        length = pieces2word.size(1)

        # Fill value for pieces that do not belong to a word: the global minimum, so a
        # masked position can never exceed a real piece in the max below.
        min_value = torch.min(bert_embs).item()

        # Max pooling word representations from pieces
        _bert_embs = bert_embs.unsqueeze(1).expand(-1, length, -1, -1)
        _bert_embs = torch.masked_fill(_bert_embs, pieces2word.eq(0).unsqueeze(-1), min_value)
        word_reps, _ = torch.max(_bert_embs, dim=2)

        word_reps = self.dropout(word_reps)
        # Pack by true sentence length so the LSTM skips padding; lengths must be on CPU.
        packed_embs = pack_padded_sequence(word_reps, sent_length.cpu(), batch_first=True, enforce_sorted=False)
        # The final hidden/cell states are not used afterwards.
        packed_outs, (hidden, _) = self.encoder(packed_embs)
        # Pad back to the longest sentence in the batch.
        word_reps, _ = pad_packed_sequence(packed_outs, batch_first=True, total_length=sent_length.max())

        # NOTE(review): presumably normalises each word vector conditioned on every other
        # word, yielding a [B, L, L, H] grid — confirm against LayerNorm in model.py.
        cln = self.cln(word_reps.unsqueeze(2), word_reps)

        dis_emb = self.dis_embs(dist_inputs)
        # Region id = mask + lower-triangular part of the mask: 0 outside the mask,
        # 1 strictly above the diagonal, 2 on/below it (assuming a 0/1 mask).
        tril_mask = torch.tril(grid_mask2d.clone().long())
        reg_inputs = tril_mask + grid_mask2d.clone().long()
        reg_emb = self.reg_embs(reg_inputs)

        conv_inputs = torch.cat([dis_emb, reg_emb, cln], dim=-1)
        # Zero out grid cells that fall outside the mask, before and after the convolution.
        conv_inputs = torch.masked_fill(conv_inputs, grid_mask2d.eq(0).unsqueeze(-1), 0.0)
        conv_outputs = self.convLayer(conv_inputs)
        conv_outputs = torch.masked_fill(conv_outputs, grid_mask2d.eq(0).unsqueeze(-1), 0.0)
        outputs = self.predictor(word_reps, word_reps, conv_outputs)

        return outputs

In [8]:
vocab.id2label

{0: '<pad>', 1: '<suc>', 2: 'org', 3: 'misc', 4: 'per', 5: 'loc'}

In [9]:
# Record the number of labels in the config dict.
# NOTE(review): config_data is re-read from disk in a later cell, which discards this key.
config_data['label_num'] = len(vocab.id2label)

数据传入模型

In [10]:
# for i, data_batch in enumerate(train_loader):
#     break

In [11]:
# model = Model(config_data)

## 解码部分

### 拿到数据

In [12]:
class Config:
    """Attribute-style view of the JSON configuration dictionary.

    Every key listed in ``_FIELDS`` is copied from ``config_data`` onto the instance
    under the same name; a missing key raises ``KeyError``. Keys of ``config_data``
    that are not listed are ignored.
    """

    # Keys to copy, in the order the attributes are created. The order is visible
    # in ``__repr__`` because that prints the instance ``__dict__``.
    _FIELDS = (
        "dataset",
        "dist_emb_size",
        "type_emb_size",
        "lstm_hid_size",
        "conv_hid_size",
        "bert_hid_size",
        "biaffine_size",
        "ffnn_hid_size",
        "dilation",
        "emb_dropout",
        "conv_dropout",
        "out_dropout",
        "epochs",
        "batch_size",
        "learning_rate",
        "weight_decay",
        "clip_grad_norm",
        "bert_name",
        "bert_learning_rate",
        "warm_factor",
        "use_bert_last_4_layers",
        "seed",
    )

    def __init__(self, config_data):
        for field_name in self._FIELDS:
            setattr(self, field_name, config_data[field_name])

    def __repr__(self):
        # Shows every instance attribute, including ones attached after construction.
        return str(self.__dict__.items())
    

import json
from transformers import AutoTokenizer
import data_loader

# Re-read the configuration from disk. This rebinds config_data, so any key added to
# the previous dict in earlier cells (e.g. 'label_num') is not carried over.
with open("./config/conll03.json", "r", encoding="utf-8") as f:
    config_data = json.load(f)

tokenizer = AutoTokenizer.from_pretrained(config_data['bert_name'], cache_dir="./cache")
# Wrap the dict so the settings are reachable as attributes (config.batch_size, ...).
config = Config(config_data)

import utils

# Log the full configuration, then attach the logger to the config object.
logger = utils.get_logger(config.dataset)
logger.info(config)
config.logger = logger
# Only the first element of load_data_bert's result is kept; it is iterated below to
# build the train/dev/test loaders.
# NOTE(review): load_data_bert presumably also sets config.label_num, which Model
# reads — confirm in data_loader.
datasets = data_loader.load_data_bert(config)[0]


from torch.utils.data import DataLoader

# One DataLoader per split, in the order the splits appear in `datasets`.
# Only the first split (training) is shuffled and drops its last incomplete batch;
# the remaining splits are iterated in order and in full.
train_loader, dev_loader, test_loader = [
    DataLoader(
        dataset=split,
        batch_size=config.batch_size,
        collate_fn=data_loader.collate_fn,
        shuffle=(split_index == 0),
        num_workers=4,
        drop_last=(split_index == 0),
    )
    for split_index, split in enumerate(datasets)
]

2022-05-12 14:12:26 - INFO: dict_items([('dataset', 'conll03'), ('dist_emb_size', 20), ('type_emb_size', 20), ('lstm_hid_size', 768), ('conv_hid_size', 96), ('bert_hid_size', 1024), ('biaffine_size', 768), ('ffnn_hid_size', 128), ('dilation', [1, 2, 3]), ('emb_dropout', 0.5), ('conv_dropout', 0.5), ('out_dropout', 0.33), ('epochs', 10), ('batch_size', 2), ('learning_rate', 0.001), ('weight_decay', 0), ('clip_grad_norm', 1.0), ('bert_name', 'bert-large-cased'), ('bert_learning_rate', 1e-05), ('warm_factor', 0.1), ('use_bert_last_4_layers', True), ('seed', 123)])
2022-05-12 14:12:36 - INFO: 
+---------+-----------+----------+
| conll03 | sentences | entities |
+---------+-----------+----------+
|  train  |   17291   |  29441   |
|   dev   |    3453   |   5648   |
|   test  |    3453   |   5648   |
+---------+-----------+----------+


### 拿到训练好的模型

#### 先实例化模型

In [13]:
from model import Model

# `Model` here is the class imported from the project's model module on the line above;
# that import rebinds the name previously bound to the Model class defined in this notebook.
model = Model(config)

Some weights of the model checkpoint at bert-large-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


#### 再传入模型参数

In [14]:
import torch
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '7'

model_path = "./model.pt"
# Deserialize the checkpoint onto the CPU regardless of the device it was saved from:
# the model itself is still on the CPU at this point (it is moved with model.cuda()
# later), and without map_location torch.load tries to restore each tensor to its
# original device, which fails when that device is not available.
model.load_state_dict(torch.load(model_path, map_location="cpu"))

<All keys matched successfully>

### 预测（解码）

先取一个Batch的数据

In [15]:
# Take the first batch of the test loader. Unlike a for/break loop, this raises
# StopIteration when the loader is empty instead of silently leaving `data_batch`
# unbound (or bound to a stale value from an earlier run).
data_batch = next(iter(test_loader))

In [16]:
# 标签和索引的对应关系
vocab.id2label

{0: '<pad>', 1: '<suc>', 2: 'org', 3: 'misc', 4: 'per', 5: 'loc'}

In [17]:
# Gold entity strings, one collection per sentence (last element of the batch)
entity_text = data_batch[-1]    # [{'2-#-5', '7-#-4'}, {'0-1-#-4'}]

# Model inputs: move every remaining batch element to the GPU
data_batch = [data.cuda() for data in data_batch[:-1]]
bert_inputs, grid_labels, grid_mask2d, pieces2word, dist_inputs, sent_length = data_batch

In [18]:
# Shapes observed for this batch:
# bert_inputs.shape   # torch.Size([2, 33])
# grid_labels.shape   # torch.Size([2, 12, 12])
# grid_mask2d.shape   # torch.Size([2, 12, 12])
# pieces2word.shape   # torch.Size([2, 12, 33])
# sent_length.shape   # torch.Size([2])
# For this batch the token length is 12 and the subword length is 33.

import utils
# Evaluation mode (disables dropout), on the GPU
model.eval()
model.cuda()

# Pure inference: run without autograd bookkeeping, as the full test-set loop in this
# notebook does, so no computation graph is kept alive for the outputs.
with torch.no_grad():
    outputs = model(bert_inputs, grid_mask2d, dist_inputs, pieces2word, sent_length)

In [19]:
# outputs是预测Token之间的关系
# outputs.shape   # torch.Size([2, 12, 12, 6])

In [20]:
# Most likely label index for every token pair (argmax over the last, label dimension).
arg_outputs = torch.argmax(outputs, -1)
# Decode the label grid into per-sentence entity predictions; both arguments are
# handed over as NumPy arrays on the CPU.
predictions = utils.get_predictions(arg_outputs.cpu().numpy(), sent_length.cpu().numpy())

In [21]:
from estimate_entity_prob import estimate_entity_prob

# Pair each predicted entity string with a probability derived from the raw model outputs.
# NOTE(review): the softmax deprecation warning printed under this cell originates inside
# estimate_entity_prob (implicit `dim`), not in this notebook.
entity_and_prob = estimate_entity_prob(model_outputs=outputs, predictions=predictions)

  tag_prob = F.softmax(logits)[tag]


In [22]:
entity_and_prob

[[('2-#-5', 0.9996621608734131), ('7-#-5', 0.9658114910125732)],
 [('0-1-#-4', 0.9999997019767761)]]

对测试数据进行预测

In [24]:
import utils
from estimate_entity_prob import estimate_entity_prob

# Evaluation mode, on the GPU
model.eval()
model.cuda()

# One entry per batch: predicted (entity, probability) lists and gold entity strings.
pred_result, label_result = [], []

with torch.no_grad():
    for i, data_batch in enumerate(test_loader):
        # The last batch element holds the gold entity strings; everything before
        # it is a tensor that has to be moved to the GPU.
        *batch_tensors, entity_text = data_batch
        label_result.append(entity_text)
        data_batch = [tensor.cuda() for tensor in batch_tensors]
        (bert_inputs, grid_labels, grid_mask2d,
         pieces2word, dist_inputs, sent_length) = data_batch

        outputs = model(bert_inputs, grid_mask2d, dist_inputs, pieces2word, sent_length)

        # Label index per token pair, decoded into entity strings per sentence.
        arg_outputs = outputs.argmax(dim=-1)
        predictions = utils.get_predictions(
            arg_outputs.cpu().numpy(),
            sent_length.cpu().numpy(),
        )

        entity_and_prob = estimate_entity_prob(model_outputs=outputs, predictions=predictions)
        pred_result.append(entity_and_prob)
        

  tag_prob = F.softmax(logits)[tag]


In [25]:
# Flatten the per-batch results into one entry per sentence, keeping batch order.
labels, preds = [], []
for batch_labels, batch_preds in zip(label_result, pred_result):
    labels.extend(batch_labels)
    preds.extend(batch_preds)

In [26]:
labels[: 5]

[{'2-#-5', '7-#-4'},
 {'0-1-#-4'},
 {'0-#-5', '2-3-4-#-5'},
 {'0-#-5', '15-#-5', '6-7-#-3'},
 {'1-#-5', '23-#-5'}]

In [36]:
vocab.id2label

{0: '<pad>', 1: '<suc>', 2: 'org', 3: 'misc', 4: 'per', 5: 'loc'}

In [27]:
preds[: 5]

[[('2-#-5', 0.9996621608734131), ('7-#-5', 0.9658114910125732)],
 [('0-1-#-4', 0.9999997019767761)],
 [('0-#-5', 0.9995947480201721), ('2-3-4-#-5', 0.9999992847442627)],
 [('0-#-5', 0.9999983310699463),
  ('15-#-5', 0.9999908208847046),
  ('6-7-#-3', 0.9988368153572083)],
 [('1-#-5', 0.9999998807907104), ('23-#-5', 0.9999984502792358)]]

In [42]:
def get_entities(sentence, entity_text, id2label):
    """
    Turn encoded entity strings into (surface text, label name) pairs.

    Each encoded entity looks like "i-j-...-#-t": the dash-separated pieces before the
    last two are token indices into `sentence`, and the final piece is the label id,
    which is looked up in `id2label`.

    :param sentence: sequence of tokens of the original sentence
    :param entity_text: iterable of encoded entity strings, e.g. {'2-#-5', '0-1-#-4'}
    :param id2label: mapping from label id to label name
    :return: list of (space-joined entity tokens, label name) tuples, in iteration order
    """
    def _decode(encoded):
        pieces = encoded.split("-")
        label_id = int(pieces[-1])
        # Drop the trailing "#" marker and label id; what remains are token positions.
        positions = list(map(int, pieces[:-2]))
        words = [sentence[position] for position in positions]
        return " ".join(words), id2label[label_id]

    return [_decode(encoded) for encoded in entity_text]

In [43]:
get_entities(test_data[0]['sentence'], labels[0], vocab.id2label)

[('JAPAN', 'loc'), ('CHINA', 'per')]

In [45]:
# Keep only the encoded entity string (first element) of each prediction for sentence 0.
pred = [item[0] for item in preds[0]]
pred

['2-#-5', '7-#-5']

In [47]:
# Keep only the probability (second element) of each prediction for sentence 0.
prob = [item[1] for item in preds[0]]
prob

[0.9996621608734131, 0.9658114910125732]

In [46]:
get_entities(test_data[0]['sentence'], pred, vocab.id2label)

[('JAPAN', 'loc'), ('CHINA', 'loc')]

写入预测文件

In [50]:
import os

# open(..., "w") does not create missing parent directories, so make sure the output
# folder exists before writing; otherwise a fresh checkout fails with FileNotFoundError.
output_path = "./predictions/conll03_pred.txt"
os.makedirs(os.path.dirname(output_path), exist_ok=True)

# For every test sentence write: the raw text, the gold entities, the predicted
# entities and the probability of each predicted entity, followed by a blank line.
# labels[i] / preds[i] are expected to line up with test_data[i].
with open(output_path, "w", encoding="utf-8") as f:
    for i, sample in enumerate(test_data):
        sentence = sample['sentence']
        # Split each (entity string, probability) prediction into its two parts.
        pred = [item[0] for item in preds[i]]
        prob = [item[1] for item in preds[i]]

        f.write("raw sentence: \n")
        f.write(" ".join(sentence) + "\n")
        f.write("true entities: \n")
        f.write(str(get_entities(sentence, labels[i], vocab.id2label)) + "\n")
        f.write("predicted entities: \n")
        f.write(str(get_entities(sentence, pred, vocab.id2label)) + "\n")
        f.write("predicted entities prob: \n")
        f.write(str(prob) + "\n")
        f.write("\n")