In [1]:
from transformers import AutoTokenizer
import torch
from datasets import load_dataset, load_from_disk
import torch.cuda as cuda

#加载分词器
tokenizer = AutoTokenizer.from_pretrained('hfl/chinese-roberta-wwm-ext')

device = torch.device('cuda' if cuda.is_available() else 'cpu')
print(device)

cuda


In [3]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, split):
        #离线加载数据集
        dataset = load_from_disk(dataset_path='./msra')[split]

        #过滤掉太长的句子
        def f(data):
            return len(data['tokens']) <= 512 - 2

        dataset = dataset.filter(f)

        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        tokens = self.dataset[i]['tokens']
        labels = self.dataset[i]['ner_tags']

        return tokens, labels


dataset = Dataset('train')
print(len(dataset))
dataset[0]

Loading cached processed dataset at d:\暂存\ResumeSystem\ner\xt\train\cache-343894c14063acb7.arrow


19999


(['2',
  '0',
  '0',
  '9',
  '年',
  '：',
  '李',
  '民',
  '基',
  '《',
  'E',
  't',
  'e',
  'r',
  'n',
  'a',
  'l',
  '#',
  'S',
  'u',
  'm',
  'm',
  'e',
  'r',
  '》'],
 [0, 0, 0, 0, 0, 0, 1, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [7]:
#数据整理函数
def collate_fn(data):
    tokens = [i[0] for i in data]
    labels = [i[1] for i in data]

    inputs = tokenizer.batch_encode_plus(tokens,
                                         truncation=True,
                                         padding=True,
                                         return_tensors='pt',
                                         is_split_into_words=True)

    lens = inputs['input_ids'].shape[1]

    for i in range(len(labels)):
        labels[i] = [7] + labels[i]
        labels[i] += [7] * lens
        labels[i] = labels[i][:lens]

    # 将 inputs 和 labels 移动到设备上
    inputs = {key: value.to(device) for key, value in inputs.items()}
    labels = torch.LongTensor(labels).to(device)

    return inputs, labels

In [4]:
# 数据加载器
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=8,
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)

for i, (inputs, labels) in enumerate(loader):
    inputs = {k: v.to(device) for k, v in inputs.items()}
    labels = labels.to(device)
    break


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [4]:
from transformers import AutoModel

#加载预训练模型
pretrained = AutoModel.from_pretrained('hfl/chinese-roberta-wwm-ext')
pretrained = pretrained.to(device)
pretrained(**inputs).last_hidden_state.shape

Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


NameError: name 'inputs' is not defined

In [5]:
# 定义下游模型
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.tuning = False
        self.pretrained = None

        self.rnn = torch.nn.GRU(768, 768, batch_first=True)
        self.fc = torch.nn.Linear(768, 8)
        
    def forward(self, inputs):
        if self.tuning:
            out = self.pretrained(**inputs).last_hidden_state
        else:
            with torch.no_grad():
                out = pretrained(**inputs).last_hidden_state

        out, _ = self.rnn(out)
        
        out = self.fc(out).softmax(dim=2)
        
        return out

    def fine_tuning(self, tuning):
        self.tuning = tuning
        if tuning:
            for i in pretrained.parameters():
                i.requires_grad = True

            pretrained.train()
            self.pretrained = pretrained
        else:
            for i in pretrained.parameters():
                i.requires_grad_(False)

            pretrained.eval()
            self.pretrained = None


model = Model()
model = model.to(device)
model(inputs).shape

NameError: name 'inputs' is not defined

In [9]:
#对计算结果和label变形,并且移除pad
def reshape_and_remove_pad(outs, labels, attention_mask):
    #变形,便于计算loss
    #[b, lens, 8] -> [b*lens, 8]
    outs = outs.reshape(-1, 8)
    #[b, lens] -> [b*lens]
    labels = labels.reshape(-1)

    #忽略对pad的计算结果
    #[b, lens] -> [b*lens - pad]
    select = attention_mask.reshape(-1) == 1
    outs = outs[select]
    labels = labels[select]

    return outs, labels


reshape_and_remove_pad(torch.randn(2, 3, 8), torch.ones(2, 3),
                       torch.ones(2, 3))

(tensor([[ 0.3735,  0.1839, -1.2947, -0.5799,  1.4103,  1.9831,  1.1971,  0.4692],
         [-0.4211,  0.2562,  0.1699, -0.2309,  0.3574,  0.7039,  0.0447,  1.1816],
         [-0.0445, -0.8757, -1.3016,  1.7092, -0.9039, -1.2696,  0.9966,  0.3699],
         [ 1.0213, -0.4037, -1.7713,  0.2936,  1.2625,  1.4149, -0.2500, -1.3706],
         [-0.1633,  0.1126, -0.5640, -0.5190, -0.7649, -0.6759, -0.1617, -0.0638],
         [-0.0021,  1.6459, -0.0471, -0.4606,  0.1794,  0.8733, -0.3115, -1.3731]]),
 tensor([1., 1., 1., 1., 1., 1.]))

In [11]:
# 获取正确数量和总数
def get_correct_and_total_count(labels, outs):
    #[b*lens, 8] -> [b*lens]
    outs = outs.argmax(dim=1)
    correct = (outs == labels).sum().item()
    total = len(labels)

    # 计算除了0以外元素的正确率
    select = labels != 0
    outs = outs[select]
    labels = labels[select]
    correct_content = (outs == labels).sum().item()
    total_content = len(labels)

    return correct, total, correct_content, total_content

get_correct_and_total_count(torch.ones(16), torch.randn(16, 8))

(3, 16, 3, 16)

In [9]:
from transformers import AdamW
from torch.optim.lr_scheduler import StepLR
def train(epochs):
    lr = 1e-5 if model.tuning else 3e-4
    optimizer = AdamW(model.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()
    scheduler = StepLR(optimizer, step_size=4, gamma=0.3)
    
    model.train()
    for epoch in range(epochs):
        for step, (inputs, labels) in enumerate(loader):
            outs = model(inputs)
            outs, labels = reshape_and_remove_pad(outs, labels,
                                                  inputs['attention_mask'])
            loss = criterion(outs, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if step % 400 == 0:
                counts = get_correct_and_total_count(labels, outs)

                accuracy = counts[0] / counts[1]
                accuracy_content = counts[2] / counts[3]

                print(epoch, step, loss.item(), accuracy, accuracy_content)
        scheduler.step()
        torch.save(model, 'ner5.model')

In [10]:
model.fine_tuning(False)
print(sum(p.numel() for p in model.parameters()) / 10000)
train(10)

354.9704




0 0 1.4061245918273926 0.8676470588235294 0.8125
0 400 1.3827636241912842 0.8911917098445595 0.8958333333333334
0 800 1.336877465248108 0.9354838709677419 0.851063829787234
0 1200 1.4235705137252808 0.8524590163934426 0.7954545454545454
0 1600 1.3927167654037476 0.8717948717948718 0.8305084745762712
0 2000 1.3143255710601807 0.9586919104991394 0.765625
0 2400 1.3752126693725586 0.8958333333333334 0.8653846153846154
1 0 1.3365840911865234 0.941908713692946 0.7894736842105263
1 400 1.395945429801941 0.8661417322834646 0.9508196721311475
1 800 1.3376469612121582 0.935251798561151 0.7777777777777778
1 1200 1.3243870735168457 0.9518987341772152 0.8769230769230769
1 1600 1.3798080682754517 0.8955223880597015 0.8620689655172413
1 2000 1.4270416498184204 0.8473282442748091 0.847457627118644
1 2400 1.3653578758239746 0.9078947368421053 0.8076923076923077
2 0 1.3485286235809326 0.9254658385093167 0.6883116883116883
2 400 1.4106417894363403 0.8541666666666666 0.7142857142857143
2 800 1.3221021890

In [11]:
model.fine_tuning(True)
print(sum(p.numel() for p in model.parameters()) / 10000)
train(15)

10581.7352
0 0 1.281691074371338 0.99581589958159 0.9818181818181818
0 400 1.4263242483139038 0.84375 0.7931034482758621
0 800 1.3035809993743896 0.9774436090225563 0.9361702127659575
0 1200 1.4049749374389648 0.861878453038674 0.576271186440678
0 1600 1.3129820823669434 0.961038961038961 0.9464285714285714
0 2000 1.2954390048980713 0.9794520547945206 0.9577464788732394
0 2400 1.326935052871704 0.9532710280373832 0.9411764705882353
1 0 1.307201862335205 0.9702970297029703 0.9782608695652174
1 400 1.4246934652328491 0.8504672897196262 0.7837837837837838
1 800 1.2742036581039429 1.0 1.0
1 1200 1.286851406097412 0.9873417721518988 0.9814814814814815
1 1600 1.5959538221359253 0.664 0.8620689655172413
1 2000 1.2948931455612183 0.9776785714285714 0.9074074074074074
1 2400 1.2933577299118042 0.9789915966386554 0.8979591836734694
2 0 1.3438622951507568 0.9298780487804879 0.7529411764705882
2 400 1.2742880582809448 1.0 1.0
2 800 1.339837670326233 0.9343065693430657 0.8615384615384616
2 1200 1.2

In [12]:
def test():
    model_load = torch.load('ner2.model')
    model_load.eval()

    loader_test = torch.utils.data.DataLoader(dataset=Dataset('validation'),
                                              batch_size=128,
                                              collate_fn=collate_fn,
                                              shuffle=True,
                                              drop_last=True)

    correct = 0
    total = 0

    correct_content = 0
    total_content = 0

    for step, (inputs, labels) in enumerate(loader_test):
        if step == 5:
            break
        print(step)

        with torch.no_grad():
            #[b, lens] -> [b, lens, 8] -> [b, lens]
            outs = model_load(inputs)

        #对outs和label变形,并且移除pad
        #outs -> [b, lens, 8] -> [c, 8]
        #labels -> [b, lens] -> [c]
        outs, labels = reshape_and_remove_pad(outs, labels,
                                              inputs['attention_mask'])

        counts = get_correct_and_total_count(labels, outs)
        correct += counts[0]
        total += counts[1]
        correct_content += counts[2]
        total_content += counts[3]

    print(correct / total, correct_content / total_content)


test()

Loading cached processed dataset at d:\暂存\ResumeSystem\ner\data\validation\cache-26bc23ad408c2b6f.arrow


0
1
2
3
4
0.9970202622169249 0.9840819886611426


In [13]:
def custom_predict(sentence):

    model_load = torch.load('ner5.model')
    model_load = model_load.to(device)
    model_load.eval()

    inputs = tokenizer.encode_plus([sentence],
                                   truncation=True,
                                   padding=True,
                                   return_tensors='pt',
                                   is_split_into_words=True).to(device)

    with torch.no_grad():
        outputs = model_load(inputs)

    # [b, lens] -> [b, lens]
    preds = outputs.argmax(dim=2)[0]

    result = ''

    res = [set() for _ in range(3)]
    tmp = ''
    current_flag = -1

    for i in range(len(preds)):
        if inputs['attention_mask'][0][i] == 1:
            result += tokenizer.decode(inputs['input_ids'][0][i])+' '
            result += str(preds[i].item())+' '
            num = preds[i].item()
            # num 不为0和7和#表示该词为关键词
            if (num != 0 and num != 7 and num != '#'):
                # 关键词开始为奇数
                if (num & 1):
                    # 将形如 广5东6广5州6 拆成两个词
                    if (len(tmp) > 1):
                        res[(current_flag-1)//2].add(tmp)
                        tmp = ''
                    current_flag = num
                    tmp += tokenizer.decode(inputs['input_ids'][0][i])

                    # 防止形如 X4X4 出现
                elif (num & 1 == 0 and current_flag != -1):
                    tmp += tokenizer.decode(inputs['input_ids'][0][i])
            else:
                if (len(tmp) > 1):
                    # current_flag 1对应姓名下标0，3对应组织下表1，5对应地点下表2
                    res[(current_flag-1)//2].add(tmp)
                    tmp = ''
                    current_flag = -1

    return res

In [23]:
input_sentence = "李白在上海的华东理工大学读书"
output_prediction = custom_predict(input_sentence)
print(
    f'姓名{list(output_prediction[0])},组织{list(output_prediction[1])},地点{list(output_prediction[2])}')

姓名[],组织['统战统战部'],地点[]
