先获取数据

In [1]:
import json

with open("./config/conll03.json", "r", encoding="utf-8") as f:
    config_data = json.load(f)
    
with open("./data/{}/train.json".format(config_data['dataset']), "r", encoding="utf-8") as f:
    train_data = json.load(f)

In [2]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(config_data['bert_name'], cache_dir="./cache")

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from data_loader import Vocabulary
from data_loader import fill_vocab

vocab = Vocabulary()
entity_num = fill_vocab(vocab, train_data)
print(entity_num)   # 实体数目

29441


标签和索引的对应关系

In [4]:
vocab.id2label

{0: '<pad>', 1: '<suc>', 2: 'org', 3: 'misc', 4: 'per', 5: 'loc'}

如何从原始输入得到模型输入数据

In [5]:
from data_loader import RelationDataset
from data_loader import process_bert

train_dataset = RelationDataset(*process_bert(train_data, tokenizer, vocab))

In [6]:
from torch.utils.data import DataLoader
import data_loader

train_loader = DataLoader(train_dataset, batch_size=config_data['batch_size'], shuffle=True, num_workers=4, drop_last=True)

模型部分

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from model import LayerNorm
from model import ConvolutionLayer
from model import Biaffine
from model import MLP
from model import CoPredictor
from transformers import AutoModel

class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.use_bert_last_4_layers = config.use_bert_last_4_layers

        self.lstm_hid_size = config.lstm_hid_size
        self.conv_hid_size = config.conv_hid_size

        lstm_input_size = 0

        self.bert = AutoModel.from_pretrained(config.bert_name, cache_dir="./cache/", output_hidden_states=True)
        lstm_input_size += config.bert_hid_size

        self.dis_embs = nn.Embedding(20, config.dist_emb_size)
        self.reg_embs = nn.Embedding(3, config.type_emb_size)

        self.encoder = nn.LSTM(lstm_input_size, config.lstm_hid_size // 2, num_layers=1, batch_first=True,
                               bidirectional=True)

        conv_input_size = config.lstm_hid_size + config.dist_emb_size + config.type_emb_size

        self.convLayer = ConvolutionLayer(conv_input_size, config.conv_hid_size, config.dilation, config.conv_dropout)
        self.dropout = nn.Dropout(config.emb_dropout)
        self.predictor = CoPredictor(config.label_num, config.lstm_hid_size, config.biaffine_size,
                                     config.conv_hid_size * len(config.dilation), config.ffnn_hid_size,
                                     config.out_dropout)

        self.cln = LayerNorm(config.lstm_hid_size, config.lstm_hid_size, conditional=True)

    def forward(self, bert_inputs, grid_mask2d, dist_inputs, pieces2word, sent_length):
        '''
        :param bert_inputs: [B, L'], L': num of subwords + 2
        :param grid_mask2d: [B, L, L], L: num of tokens
        :param dist_inputs: [B, L, L], distance between tokens
        :param pieces2word: [B, L, L'], token和subword的映射
        :param sent_length: [B]
        :return:
        '''
        bert_embs = self.bert(input_ids=bert_inputs, attention_mask=bert_inputs.ne(0).float())
        if self.use_bert_last_4_layers:
            bert_embs = torch.stack(bert_embs[2][-4:], dim=-1).mean(-1)
        else:
            bert_embs = bert_embs[0]

        length = pieces2word.size(1)

        min_value = torch.min(bert_embs).item()

        # Max pooling word representations from pieces
        _bert_embs = bert_embs.unsqueeze(1).expand(-1, length, -1, -1)
        _bert_embs = torch.masked_fill(_bert_embs, pieces2word.eq(0).unsqueeze(-1), min_value)
        word_reps, _ = torch.max(_bert_embs, dim=2)

        word_reps = self.dropout(word_reps)
        packed_embs = pack_padded_sequence(word_reps, sent_length.cpu(), batch_first=True, enforce_sorted=False)
        packed_outs, (hidden, _) = self.encoder(packed_embs)
        word_reps, _ = pad_packed_sequence(packed_outs, batch_first=True, total_length=sent_length.max())

        cln = self.cln(word_reps.unsqueeze(2), word_reps)

        dis_emb = self.dis_embs(dist_inputs)
        tril_mask = torch.tril(grid_mask2d.clone().long())
        reg_inputs = tril_mask + grid_mask2d.clone().long()
        reg_emb = self.reg_embs(reg_inputs)

        conv_inputs = torch.cat([dis_emb, reg_emb, cln], dim=-1)
        conv_inputs = torch.masked_fill(conv_inputs, grid_mask2d.eq(0).unsqueeze(-1), 0.0)
        conv_outputs = self.convLayer(conv_inputs)
        conv_outputs = torch.masked_fill(conv_outputs, grid_mask2d.eq(0).unsqueeze(-1), 0.0)
        outputs = self.predictor(word_reps, word_reps, conv_outputs)

        return outputs

In [8]:
vocab.id2label

{0: '<pad>', 1: '<suc>', 2: 'org', 3: 'misc', 4: 'per', 5: 'loc'}

In [9]:
config_data['label_num'] = len(vocab.id2label)

数据传入模型

In [10]:
for i, data_batch in enumerate(train_loader):
    break

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "d:\Software\Anaconda\anaconda\envs\w2ner\lib\site-packages\torch\utils\data\_utils\worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "d:\Software\Anaconda\anaconda\envs\w2ner\lib\site-packages\torch\utils\data\_utils\fetch.py", line 52, in fetch
    return self.collate_fn(data)
  File "d:\Software\Anaconda\anaconda\envs\w2ner\lib\site-packages\torch\utils\data\_utils\collate.py", line 84, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "d:\Software\Anaconda\anaconda\envs\w2ner\lib\site-packages\torch\utils\data\_utils\collate.py", line 84, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "d:\Software\Anaconda\anaconda\envs\w2ner\lib\site-packages\torch\utils\data\_utils\collate.py", line 56, in default_collate
    return torch.stack(batch, 0, out=out)
RuntimeError: stack expects each tensor to be equal size, but got [11] at entry 0 and [12] at entry 1


In [None]:
model = Model(config_data)