In [1]:
import torch
import tqdm
import json
import argparse
import numpy as np
import torch.optim as optim
from collections import defaultdict
from torch.utils.data import DataLoader
from bert4torch.models import BaseModel
from bert4torch.snippets import sequence_padding, Callback, ListDataset, EarlyStopping, get_pool_emb
from bert4torch.callbacks import AdversarialTraining
from bert4torch.optimizers import extend_with_exponential_moving_average, get_linear_schedule_with_warmup, Lion
from bert4torch.tokenizers import Tokenizer
from bert4torch.losses import MultilabelCategoricalCrossentropy
from transformers import BertModel,AutoTokenizer
from PromptModel import PromptTable
import warnings

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Parameters
warnings.filterwarnings("ignore")
maxlen = 150
batch_size = 16

# Label dictionary: entity type -> integer id, plus the reverse lookup.
categories_label2id = {"PER": 1, "BOOK": 2, "OFI": 3}
categories_id2label = {idx: name for name, idx in categories_label2id.items()}
ner_vocab_size = len(categories_label2id)
ner_head_size = 64

# BERT base
config_path = './GujiBERT/config.json'
checkpoint_path = './GujiBERT/pytorch_model.bin'
dict_path = './GujiBERT/vocab.txt'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Build the tokenizer from the vocabulary file.
tokenizer = Tokenizer(dict_path, do_lower_case=True)

In [3]:
# Model definition
class Model(BaseModel):
    """BERT encoder followed by a PromptTable head.

    The attribute names ``bert`` and ``prompttable`` are kept unchanged so
    that previously saved weights still match the module names when loaded.
    """

    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained("./GujiBERT")
        # Read the hidden width from the loaded encoder instead of hard-coding
        # 768, so the head always matches the checkpoint in ./GujiBERT.
        self.prompttable = PromptTable(
            hidden_size=self.bert.config.hidden_size,
            head_size=ner_head_size,
        )

    def forward(self, token_ids):
        """Return the PromptTable logits for a batch of token ids.

        Args:
            token_ids: integer tensor of token ids; id 0 is treated as padding
                when the mask for the head is built.

        Returns:
            Whatever ``PromptTable`` returns for the encoded sequence and mask.
        """
        # Only the last layer is consumed, so do not request
        # output_hidden_states (it would keep every layer's activations alive).
        sequence_output = self.bert(input_ids=token_ids)['last_hidden_state']
        # Mask is 1 for every non-zero token id and 0 otherwise.
        logit = self.prompttable(sequence_output, token_ids.gt(0).long())
        return logit


In [4]:
# Load the models: five instances of Model, each restored from its own
# weights file under ./bestmodel and bound to model1 .. model5.
_ensemble = []
for _idx in range(1, 6):
    _member = Model().to(device)
    _member.load_weights(f'./bestmodel/model{_idx}.pt')
    _ensemble.append(_member)
model1, model2, model3, model4, model5 = _ensemble

Some weights of the model checkpoint at ./GujiBERT were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at ./GujiBERT and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

In [6]:
def package_result(context, prediction, prefix_len=3):
    """Render predicted entities inline as ``{text|LABEL}`` markup.

    Args:
        context: the input string, including ``prefix_len`` leading characters
            that are not part of the output.
        prediction: iterable of ``(start, end, label)`` entries, where
            ``start`` and ``end`` are inclusive character indices into
            ``context`` and ``label`` is a string.
        prefix_len: number of leading characters of ``context`` to drop from
            the result (default 3, the value the original code hard-coded).

    Returns:
        ``context`` without its prefix, with every accepted entity wrapped as
        ``{entity|label}``.

    Entities are accepted in the order given. An entity is skipped when it
    overlaps any character of an already accepted entity (so the markup is
    never nested or crossing), when it reaches into the prefix, or when its
    indices are out of range or reversed.
    """
    pieces = list(context)
    taken = [False] * len(pieces)
    for entity in prediction:
        start, end, label = entity[0], entity[1], entity[2]
        # Reject spans that touch the prefix, are reversed, or fall outside
        # the text: they cannot be rendered as well-formed markup.
        if not (prefix_len <= start <= end < len(pieces)):
            continue
        # Check the whole span, not just its endpoints, so an entity nested
        # inside or straddling an accepted one is rejected too.
        if any(taken[start:end + 1]):
            continue
        pieces[start] = '{' + pieces[start]
        pieces[end] = pieces[end] + '|' + label + '}'
        taken[start:end + 1] = [True] * (end - start + 1)
    # Drop the prefix by character position, so inserted markup can never
    # shift what gets cut off.
    return ''.join(pieces[prefix_len:])

In [9]:
# Load the test set
# NOTE(review): load_json is not defined or imported in any cell shown above;
# it presumably came from a deleted cell, so this fails on a fresh kernel — confirm.
test = load_json('./GuNER2023_test_public.txt')

In [19]:
# Model 1 predictions: creates one vote counter (entity -> count) per test text.
labellist = []
threshold = 0
for text in test:
    tokens = tokenizer.tokenize(text, maxlen=maxlen)
    mapping = tokenizer.rematch(text, tokens)
    token_ids = torch.tensor([tokenizer.tokens_to_ids(tokens)], device=device).long()
    # Copy the 2-D score table to host memory once; indexing the device tensor
    # element by element in the loops below forces a transfer per lookup.
    output = model1.predict(token_ids)[0].detach().cpu().numpy()
    S = set()
    for start, end in zip(*np.where(output > threshold)):
        # output[start, end] > threshold already holds for every pair yielded
        # here; the mirrored cell does not depend on rel, so test it once.
        if start > end or not (output[end, start] > threshold):
            continue
        # Tokens without character offsets cannot bound an entity. rematch is
        # expected to return an empty list for special tokens such as
        # [CLS]/[SEP] — confirm for the installed bert4torch version.
        if not mapping[start] or not mapping[end]:
            continue
        # Rows/columns 1..3 of the table are read as the slots of the three
        # labels in categories_id2label.
        for rel in range(1, 4):
            if output[start, rel] > threshold and output[rel, end] > threshold:
                S.add((mapping[start][0], mapping[end][-1], categories_id2label[rel]))
    labelset = defaultdict(int)
    for entity in S:
        labelset[entity] += 1
    labellist.append(labelset)

In [20]:
# Model 2 predictions: adds one vote per predicted entity to labellist[i].
# Requires labellist to exist already; re-running this cell adds the votes again.
for i, text in enumerate(test, 0):
    tokens = tokenizer.tokenize(text, maxlen=maxlen)
    mapping = tokenizer.rematch(text, tokens)
    token_ids = torch.tensor([tokenizer.tokens_to_ids(tokens)], device=device).long()
    # Copy the 2-D score table to host memory once; indexing the device tensor
    # element by element in the loops below forces a transfer per lookup.
    output = model2.predict(token_ids)[0].detach().cpu().numpy()
    S = set()
    for start, end in zip(*np.where(output > threshold)):
        # output[start, end] > threshold already holds for every pair yielded
        # here; the mirrored cell does not depend on rel, so test it once.
        if start > end or not (output[end, start] > threshold):
            continue
        # Tokens without character offsets cannot bound an entity. rematch is
        # expected to return an empty list for special tokens such as
        # [CLS]/[SEP] — confirm for the installed bert4torch version.
        if not mapping[start] or not mapping[end]:
            continue
        # Rows/columns 1..3 of the table are read as the slots of the three
        # labels in categories_id2label.
        for rel in range(1, 4):
            if output[start, rel] > threshold and output[rel, end] > threshold:
                S.add((mapping[start][0], mapping[end][-1], categories_id2label[rel]))
    for entity in S:
        labellist[i][entity] += 1

In [21]:
# Model 3 predictions: adds one vote per predicted entity to labellist[i].
# Requires labellist to exist already; re-running this cell adds the votes again.
for i, text in enumerate(test, 0):
    tokens = tokenizer.tokenize(text, maxlen=maxlen)
    mapping = tokenizer.rematch(text, tokens)
    token_ids = torch.tensor([tokenizer.tokens_to_ids(tokens)], device=device).long()
    # Copy the 2-D score table to host memory once; indexing the device tensor
    # element by element in the loops below forces a transfer per lookup.
    output = model3.predict(token_ids)[0].detach().cpu().numpy()
    S = set()
    for start, end in zip(*np.where(output > threshold)):
        # output[start, end] > threshold already holds for every pair yielded
        # here; the mirrored cell does not depend on rel, so test it once.
        if start > end or not (output[end, start] > threshold):
            continue
        # Tokens without character offsets cannot bound an entity. rematch is
        # expected to return an empty list for special tokens such as
        # [CLS]/[SEP] — confirm for the installed bert4torch version.
        if not mapping[start] or not mapping[end]:
            continue
        # Rows/columns 1..3 of the table are read as the slots of the three
        # labels in categories_id2label.
        for rel in range(1, 4):
            if output[start, rel] > threshold and output[rel, end] > threshold:
                S.add((mapping[start][0], mapping[end][-1], categories_id2label[rel]))
    for entity in S:
        labellist[i][entity] += 1

In [22]:
# Model 4 predictions: adds one vote per predicted entity to labellist[i].
# Requires labellist to exist already; re-running this cell adds the votes again.
for i, text in enumerate(test, 0):
    tokens = tokenizer.tokenize(text, maxlen=maxlen)
    mapping = tokenizer.rematch(text, tokens)
    token_ids = torch.tensor([tokenizer.tokens_to_ids(tokens)], device=device).long()
    # Copy the 2-D score table to host memory once; indexing the device tensor
    # element by element in the loops below forces a transfer per lookup.
    output = model4.predict(token_ids)[0].detach().cpu().numpy()
    S = set()
    for start, end in zip(*np.where(output > threshold)):
        # output[start, end] > threshold already holds for every pair yielded
        # here; the mirrored cell does not depend on rel, so test it once.
        if start > end or not (output[end, start] > threshold):
            continue
        # Tokens without character offsets cannot bound an entity. rematch is
        # expected to return an empty list for special tokens such as
        # [CLS]/[SEP] — confirm for the installed bert4torch version.
        if not mapping[start] or not mapping[end]:
            continue
        # Rows/columns 1..3 of the table are read as the slots of the three
        # labels in categories_id2label.
        for rel in range(1, 4):
            if output[start, rel] > threshold and output[rel, end] > threshold:
                S.add((mapping[start][0], mapping[end][-1], categories_id2label[rel]))
    for entity in S:
        labellist[i][entity] += 1

In [23]:
# Model 5 predictions: adds one vote per predicted entity to labellist[i].
# Requires labellist to exist already; re-running this cell adds the votes again.
for i, text in enumerate(test, 0):
    tokens = tokenizer.tokenize(text, maxlen=maxlen)
    mapping = tokenizer.rematch(text, tokens)
    token_ids = torch.tensor([tokenizer.tokens_to_ids(tokens)], device=device).long()
    # Copy the 2-D score table to host memory once; indexing the device tensor
    # element by element in the loops below forces a transfer per lookup.
    output = model5.predict(token_ids)[0].detach().cpu().numpy()
    S = set()
    for start, end in zip(*np.where(output > threshold)):
        # output[start, end] > threshold already holds for every pair yielded
        # here; the mirrored cell does not depend on rel, so test it once.
        if start > end or not (output[end, start] > threshold):
            continue
        # Tokens without character offsets cannot bound an entity. rematch is
        # expected to return an empty list for special tokens such as
        # [CLS]/[SEP] — confirm for the installed bert4torch version.
        if not mapping[start] or not mapping[end]:
            continue
        # Rows/columns 1..3 of the table are read as the slots of the three
        # labels in categories_id2label.
        for rel in range(1, 4):
            if output[start, rel] > threshold and output[rel, end] > threshold:
                S.add((mapping[start][0], mapping[end][-1], categories_id2label[rel]))
    for entity in S:
        labellist[i][entity] += 1

In [31]:
# Fuse the predictions: for each text, keep the entities whose vote count is at least 3.
La = [
    [candidate for candidate, votes in counter.items() if votes >= 3]
    for counter in labellist
]

In [34]:
def save2file(filename, prediction):
    """Write each string in ``prediction`` to ``filename`` on its own line.

    The file is opened for writing as UTF-8, so existing content is replaced.
    """
    with open(filename, "w", encoding="utf-8") as sink:
        sink.writelines(line + '\n' for line in prediction)

In [35]:
# Generate the result file: mark up every test text with its fused entities
# and write one marked-up text per line to pred.txt.
test = load_json('./GuNER2023_test_public.txt')
predictions = [package_result(text, La[i]) for i, text in enumerate(test)]
save2file('pred.txt', predictions)