In [1]:
# Chapter 10 / Load the tokenizer
from transformers import AutoTokenizer

# Tokenizer of the 'hfl/rbt3' checkpoint; the cells below use it both to
# encode pre-split token lists and to decode ids back into text.
tokenizer = AutoTokenizer.from_pretrained('hfl/rbt3')

# Trailing expression so the notebook displays the tokenizer's repr.
tokenizer

PreTrainedTokenizerFast(name_or_path='hfl/rbt3', vocab_size=21128, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

In [2]:
# Chapter 10 / Encoding smoke test
# Two sentences, already split into single tokens, are encoded as one batch
# that is padded/truncated to at most 20 positions.
out = tokenizer.batch_encode_plus(
    [
        [
            '海', '钓', '比', '赛', '地', '点', '在', '厦', '门',
            '与', '金', '门', '之', '间', '的', '海', '域', '。',
        ],
        [
            '这', '座', '依', '山', '傍', '水', '的', '博', '物', '馆', '由', '国',
            '内', '一', '流', '的', '设', '计', '师', '主', '持', '设', '计', '。',
        ],
    ],
    is_split_into_words=True,
    max_length=20,
    truncation=True,
    padding=True,
    return_tensors='pt',
)

# Decode both encoded rows back into text, one sentence per line.
print(*map(tokenizer.decode, out['input_ids']), sep='\n')

# Dump every tensor returned by the tokenizer.
for k, v in out.items():
    print(k, v)

[CLS] 海 钓 比 赛 地 点 在 厦 门 与 金 门 之 间 的 海 域 。 [SEP]
[CLS] 这 座 依 山 傍 水 的 博 物 馆 由 国 内 一 流 的 设 计 [SEP]
input_ids tensor([[ 101, 3862, 7157, 3683, 6612, 1765, 4157, 1762, 1336, 7305,  680, 7032,
         7305,  722, 7313, 4638, 3862, 1818,  511,  102],
        [ 101, 6821, 2429,  898, 2255,  988, 3717, 4638, 1300, 4289, 7667, 4507,
         1744, 1079,  671, 3837, 4638, 6392, 6369,  102]])
token_type_ids tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
attention_mask tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])


In [3]:
#第10章/定义数据集
import torch
from datasets import load_dataset, load_from_disk


class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over one split of the locally stored NER corpus.

    Every item is the pair ``(tokens, ner_tags)`` read from one stored row.
    """

    def __init__(self, split):
        """Load ``split`` from the copy saved under ./data/peoples_daily_ner.

        Args:
            split: Key used to select the split from the object returned by
                ``load_from_disk`` (this notebook passes 'train' and
                'validation').
        """
        # Online alternative:
        # dataset = load_dataset(path='peoples_daily_ner', split=split)

        # Offline loading from disk.
        dataset = load_from_disk(
            dataset_path='./data/peoples_daily_ner')[split]

        self.dataset = dataset

        # dataset.features['ner_tags'].feature.num_classes
        # 7

        # dataset.features['ner_tags'].feature.names
        # ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC']

    def __len__(self):
        """Return the number of rows in the loaded split."""
        return len(self.dataset)

    def __getitem__(self, i):
        """Return ``(tokens, ner_tags)`` of row ``i``."""
        # Look the row up once and read both fields from it, instead of
        # indexing the underlying dataset twice per sample.
        row = self.dataset[i]
        return row['tokens'], row['ner_tags']


# Build the training split and look at its first sample.
dataset = Dataset('train')

tokens, labels = dataset[0]
print(tokens)
print(labels)

# Number of samples in the split (displayed as the cell output).
len(dataset)

['海', '钓', '比', '赛', '地', '点', '在', '厦', '门', '与', '金', '门', '之', '间', '的', '海', '域', '。']
[0, 0, 0, 0, 0, 0, 0, 5, 6, 0, 5, 6, 0, 0, 0, 0, 0, 0]


20865

In [4]:
# for i in range(100, 200):
#     text = dataset[i][0][:15]
#     label = dataset[i][1][:15]

#     if sum(label) == 0:
#         continue

#     text = ' '.join(text)
#     label = ' '.join([str(j) for j in label])

#     print(text)
#     print(label)
#     print()

In [5]:
# Chapter 10 / Pick the compute device
# Use the GPU when CUDA is available, otherwise stay on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

device

'cuda'

In [6]:
# Chapter 10 / Batch collation function
def collate_fn(data):
    """Turn a list of ``(tokens, labels)`` samples into one padded batch.

    Args:
        data: List of samples. Element 0 of each sample is a list of
            pre-split tokens, element 1 the matching list of integer tags.

    Returns:
        Tuple ``(inputs, labels)``. ``inputs`` is the tokenizer output with
        every tensor moved to ``device``. ``labels`` is a LongTensor of shape
        ``[batch, lens]`` on ``device``, where ``lens`` is the padded width
        of ``inputs['input_ids']``.
    """
    tokens = [sample[0] for sample in data]
    labels = [sample[1] for sample in data]

    # Encode the pre-split tokens; rows are padded to the longest one in the
    # batch and cut off at 512 positions.
    inputs = tokenizer.batch_encode_plus(tokens,
                                         truncation=True,
                                         padding=True,
                                         return_tensors='pt',
                                         max_length=512,
                                         is_split_into_words=True)

    # Padded width of this batch, special tokens included.
    lens = inputs['input_ids'].shape[1]

    # Label 7 marks every position that does not carry a real tag: the
    # leading special token, the trailing one and all padding.
    filler = 7
    aligned = []
    for tags in labels:
        # Two positions of each row belong to the special tokens, so at most
        # lens - 2 real tags fit. Clipping here keeps the filler label on the
        # final position even when the tokenizer had to truncate the sample;
        # without it the closing special token would receive a real tag.
        row = [filler] + list(tags[:lens - 2])
        row += [filler] * (lens - len(row))
        aligned.append(row)

    # Move the encoded tensors to the compute device.
    for k, v in inputs.items():
        inputs[k] = v.to(device)

    # Stack the equal-length label rows into a matrix on the same device.
    labels = torch.LongTensor(aligned).to(device)

    return inputs, labels

In [7]:
# Chapter 10 / Trial run of the collation function
# Hand-written batch of two (tokens, labels) samples: 18 and 35 tokens long.
data = [
    ([
        '海', '钓', '比', '赛', '地', '点', '在', '厦', '门', '与', '金', '门', '之', '间',
        '的', '海', '域', '。'
    ], [0, 0, 0, 0, 0, 0, 0, 5, 6, 0, 5, 6, 0, 0, 0, 0, 0, 0]),
    ([
        '这', '座', '依', '山', '傍', '水', '的', '博', '物', '馆', '由', '国', '内', '一',
        '流', '的', '设', '计', '师', '主', '持', '设', '计', '，', '整', '个', '建', '筑',
        '群', '精', '美', '而', '恢', '宏', '。'
    ], [
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
    ]),
]

# Trial run: collate the batch and compare the resulting shapes.
inputs, labels = collate_fn(data)

for k, v in inputs.items():
    print(k, v.shape)

print('labels', labels.shape)

input_ids torch.Size([2, 37])
token_type_ids torch.Size([2, 37])
attention_mask torch.Size([2, 37])
labels torch.Size([2, 37])


In [8]:
# Chapter 10 / Data loader
# Shuffled batches of 16 training samples, assembled by collate_fn; the last
# incomplete batch is dropped.
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=16,
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)

# Number of batches per epoch (displayed as the cell output).
len(loader)

1304

In [9]:
# Chapter 10 / Look at one sample batch
# Pull only the first batch; `inputs` and `labels` stay defined after the
# loop and are reused by the trial cells further down.
for i, (inputs, labels) in enumerate(loader):
    break

# First sentence of the batch decoded back to text, followed by its labels.
print(tokenizer.decode(inputs['input_ids'][0]))
print(labels[0])

for k, v in inputs.items():
    print(k, v.shape)

[CLS] 不 过 ， 从 此 以 后 ， 惠 普 公 司 便 成 了 他 们 的 忠 诚 客 户 。 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
tensor([7, 0, 0, 0, 0, 0, 0, 0, 0, 3, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0

In [10]:
# Chapter 10 / Load the pretrained model
from transformers import AutoModel

# Backbone for the same 'hfl/rbt3' checkpoint the tokenizer was loaded from.
pretrained = AutoModel.from_pretrained('hfl/rbt3')

pretrained.to(device)

# Parameter count, printed in units of 10,000.
print(sum(i.numel() for i in pretrained.parameters()) / 10000)

Some weights of the model checkpoint at hfl/rbt3 were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


3847.68


In [11]:
# Chapter 10 / Trial run of the backbone
# Feeds the sample batch fetched above through the backbone.
#[b, lens] -> [b, lens, 768]
pretrained(**inputs).last_hidden_state.shape

torch.Size([16, 171, 768])

In [12]:
# Chapter 10 / Downstream model
class Model(torch.nn.Module):
    """GRU + linear tagging head on top of the pretrained backbone.

    The backbone is the module-level ``pretrained`` object. While tuning is
    off it is called from outside this module under ``torch.no_grad()``;
    while tuning is on it is attached as ``self.pretrained`` so that it
    becomes part of this module.
    """

    def __init__(self):
        super().__init__()
        # Whether the model is currently in fine-tuning mode.
        self.tuning = False
        # Holds the backbone while fine-tuning, otherwise None.
        self.pretrained = None

        # Layers owned by the downstream model itself.
        self.rnn = torch.nn.GRU(input_size=768, hidden_size=768, batch_first=True)
        self.fc = torch.nn.Linear(in_features=768, out_features=8)

    def forward(self, inputs):
        """Return per-position class scores.

        Args:
            inputs: Mapping of keyword arguments for the backbone (the
                tokenizer output produced by the collation function).

        Returns:
            Tensor of shape ``[b, lens, 8]`` holding raw logits. No softmax
            is applied: the training loop feeds this output to
            ``CrossEntropyLoss``, which normalises internally, and argmax
            predictions do not depend on the normalisation.
        """
        # Use the attached backbone while fine-tuning, otherwise the external
        # one without tracking gradients.
        if self.tuning:
            out = self.pretrained(**inputs).last_hidden_state
        else:
            with torch.no_grad():
                out = pretrained(**inputs).last_hidden_state

        # Refine the backbone features with the recurrent layer.
        out, _ = self.rnn(out)

        # Project every position onto the 8 label scores.
        return self.fc(out)

    def fine_tuning(self, tuning):
        """Switch fine-tuning of the backbone on or off.

        Args:
            tuning: Truthy to train the backbone together with the head,
                falsy to freeze it.
        """
        self.tuning = tuning

        # Backbone weights receive gradients only while fine-tuning.
        for param in pretrained.parameters():
            param.requires_grad_(bool(tuning))

        if tuning:
            pretrained.train()
            self.pretrained = pretrained
        else:
            pretrained.eval()
            self.pretrained = None


# Instantiate the downstream model and move it to the compute device.
model = Model()

model.to(device)

# Trial forward pass on the sample batch; the cell displays the output shape.
model(inputs).shape

torch.Size([16, 171, 8])

In [13]:
# Chapter 10 / Flatten outputs and labels, and drop padded positions
def reshape_and_remove_pad(outs, labels, attention_mask):
    """Flatten a batch and keep only the positions the mask marks with 1.

    Args:
        outs: Scores of shape ``[b, lens, c]``.
        labels: Labels of shape ``[b, lens]``.
        attention_mask: Mask of shape ``[b, lens]``; positions equal to 1
            are kept.

    Returns:
        Tuple ``(outs, labels)`` with shapes ``[k, c]`` and ``[k]``, where
        ``k`` is the number of kept positions.
    """
    # Flatten so the loss can be computed per position.
    # [b, lens, c] -> [b*lens, c]; the class count is read from the input
    # instead of being fixed to 8.
    outs = outs.reshape(-1, outs.shape[-1])
    # [b, lens] -> [b*lens]
    labels = labels.reshape(-1)

    # Drop the positions that belong to padding.
    # [b, lens] -> [b*lens - pad]
    select = attention_mask.reshape(-1) == 1

    return outs[select], labels[select]


# Trial call with random scores and an all-ones mask, so no row is dropped.
reshape_and_remove_pad(torch.randn(2, 3, 8), torch.ones(2, 3),
                       torch.ones(2, 3))

(tensor([[-0.0526,  0.1013, -0.0173, -0.7658,  1.5053,  0.7575,  0.0168, -0.2533],
         [-0.9728,  1.1428, -0.8979, -1.1862, -0.6049,  0.4607,  0.2360,  1.0206],
         [-0.3074, -0.1413,  0.7964,  1.0810,  1.0588,  2.8770,  0.5829, -0.0668],
         [-0.4274, -0.0931, -1.0921, -0.7383, -0.0292, -0.6288, -3.0649, -0.3452],
         [-1.5576, -0.7772, -1.2155, -0.6495, -0.1605,  2.0787, -0.1997, -0.1986],
         [ 0.1147,  1.3831, -0.1156,  0.8515, -0.3147,  0.7072, -0.4293,  0.9322]]),
 tensor([1., 1., 1., 1., 1., 1.]))

In [14]:
# Chapter 10 / Count correct predictions and totals
def get_correct_and_total_count(labels, outs):
    """Count hits over all positions and over positions whose label is not 0.

    Args:
        labels: Flat label tensor of shape ``[n]``.
        outs: Score tensor of shape ``[n, classes]``.

    Returns:
        Tuple ``(correct, total, correct_content, total_content)``. The
        ``*_content`` pair ignores label 0, which is so frequent that
        including it easily inflates the accuracy.
    """
    # [n, classes] -> [n]
    predicted = outs.argmax(dim=1)
    hits = predicted == labels

    # Positions whose gold label is something other than 0.
    content = labels != 0

    return (
        hits.sum().item(),
        len(labels),
        hits[content].sum().item(),
        int(content.sum().item()),
    )


# Trial call: all-ones labels against random scores for 8 classes.
get_correct_and_total_count(torch.ones(16), torch.randn(16, 8))

(1, 16, 1, 16)

In [15]:
#第10章/训练
from transformers import AdamW
from transformers.optimization import get_scheduler


def train(epochs):
    """Train the global ``model`` on ``loader`` for ``epochs`` epochs.

    The learning rate is 2e-5 while the backbone is being fine-tuned and 5e-4
    otherwise, and decays linearly to zero over all steps. Progress is
    printed roughly 30 times over the whole run, and the whole model object
    is saved after every epoch.

    Args:
        epochs: Number of passes over ``loader``.
    """
    import os

    lr = 2e-5 if model.tuning else 5e-4

    optimizer = AdamW(model.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()
    scheduler = get_scheduler(name='linear',
                              num_warmup_steps=0,
                              num_training_steps=len(loader) * epochs,
                              optimizer=optimizer)

    # Logging interval. Clamped to at least 1: with fewer than 30 total steps
    # the integer division yields 0 and `step % 0` raises ZeroDivisionError.
    log_every = max(1, len(loader) * epochs // 30)

    # Make sure the target directory exists before the first save.
    save_path = 'model/中文命名实体识别.model'
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    model.train()
    for epoch in range(epochs):
        for step, (inputs, labels) in enumerate(loader):
            # Forward pass.
            # [b, lens] -> [b, lens, 8]
            outs = model(inputs)

            # Flatten outs and labels and drop the padded positions.
            # outs -> [b, lens, 8] -> [c, 8]
            # labels -> [b, lens] -> [c]
            outs, labels = reshape_and_remove_pad(outs, labels,
                                                  inputs['attention_mask'])

            # Gradient step.
            loss = criterion(outs, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            if step % log_every == 0:
                counts = get_correct_and_total_count(labels, outs)
                accuracy = counts[0] / counts[1]
                accuracy_content = counts[2] / counts[3]
                current_lr = optimizer.state_dict()['param_groups'][0]['lr']

                print(epoch, step, loss.item(), current_lr, accuracy,
                      accuracy_content)

        torch.save(model, save_path)

In [16]:
# Chapter 10 / Two-stage training, step 1: train only the downstream model
model.fine_tuning(False)
# Parameter count of `model`, printed in units of 10,000.
print(sum(p.numel() for p in model.parameters()) / 10000)
# train(1)

354.9704


In [17]:
# Chapter 10 / Two-stage training, step 2: train the downstream model and
# the pretrained backbone together
model.fine_tuning(True)
# Parameter count of `model`, printed in units of 10,000; the backbone is now
# attached to the model and therefore included.
print(sum(p.numel() for p in model.parameters()) / 10000)
# train(5)

4202.6504


In [39]:
# Chapter 10 / Evaluation
def test():
    """Evaluate the saved model on five shuffled validation batches.

    Prints the batch counter while running, then the overall accuracy and
    the accuracy over positions whose label is not 0.
    """
    # Load the trained model onto the current compute device, so that a
    # checkpoint saved on another device can still be restored.
    # NOTE(review): torch.load unpickles arbitrary objects; only load
    # checkpoint files from a trusted source.
    model_load = torch.load('model/中文命名实体识别.model',
                            map_location=device)
    # Use the backbone stored inside the checkpoint when there is one;
    # otherwise leave tuning off so forward() uses the external backbone
    # instead of calling a missing one.
    model_load.tuning = model_load.pretrained is not None
    model_load.eval()
    model_load.to(device)

    # Loader over the validation split.
    loader_test = torch.utils.data.DataLoader(dataset=Dataset('validation'),
                                              batch_size=128,
                                              collate_fn=collate_fn,
                                              shuffle=True,
                                              drop_last=True)

    correct = 0
    total = 0

    correct_content = 0
    total_content = 0

    # Walk over the validation batches.
    for step, (inputs, labels) in enumerate(loader_test):

        # Five batches are enough; do not go through the whole split.
        if step == 5:
            break
        print(step)

        # Forward pass without gradient tracking.
        with torch.no_grad():
            # [b, lens] -> [b, lens, 8]
            outs = model_load(inputs)

        # Flatten outs and labels and drop the padded positions.
        # outs -> [b, lens, 8] -> [c, 8]
        # labels -> [b, lens] -> [c]
        outs, labels = reshape_and_remove_pad(outs, labels,
                                              inputs['attention_mask'])

        # Accumulate the hit counts.
        counts = get_correct_and_total_count(labels, outs)
        correct += counts[0]
        total += counts[1]
        correct_content += counts[2]
        total_content += counts[3]

    print(correct / total, correct_content / total_content)


# Run the evaluation; prints the batch counter and the two accuracies.
test()

0
1
2
3
4
0.9923819197562215 0.9607843137254902


In [41]:
# Chapter 10 / Prediction
def predict():
    """Print gold and predicted tags for one shuffled validation batch.

    For every sentence three lines are printed: the decoded sentence, the
    gold tags and the predicted tags. In the tag lines a position tagged 0 is
    shown as '·'; any other position is shown as its decoded token followed
    by the tag number.
    """
    # Load the trained model onto the current compute device, so that a
    # checkpoint saved on another device can still be restored.
    # NOTE(review): torch.load unpickles arbitrary objects; only load
    # checkpoint files from a trusted source.
    model_load = torch.load('model/中文命名实体识别.model',
                            map_location=device)
    # Use the backbone stored inside the checkpoint when there is one;
    # otherwise leave tuning off so forward() uses the external backbone
    # instead of calling a missing one.
    model_load.tuning = model_load.pretrained is not None
    model_load.eval()
    model_load.to(device)

    # Loader over the validation split.
    loader_test = torch.utils.data.DataLoader(dataset=Dataset('validation'),
                                              batch_size=32,
                                              collate_fn=collate_fn,
                                              shuffle=True,
                                              drop_last=True)

    # Take a single batch.
    for i, (inputs, labels) in enumerate(loader_test):
        break

    # Forward pass without gradient tracking.
    with torch.no_grad():
        # [b, lens] -> [b, lens, 8] -> [b, lens]
        outs = model_load(inputs).argmax(dim=2)

    # Iterate over the rows actually present instead of repeating the batch
    # size as a literal.
    for i in range(len(outs)):
        # Drop the padded positions.
        select = inputs['attention_mask'][i] == 1
        input_id = inputs['input_ids'][i, select]
        out = outs[i, select]
        label = labels[i, select]

        # Original sentence.
        print(tokenizer.decode(input_id).replace(' ', ''))

        # Gold tags first, predicted tags second.
        for tag in [label, out]:
            parts = []
            for j in range(len(tag)):
                if tag[j] == 0:
                    parts.append('·')
                else:
                    parts.append(
                        tokenizer.decode(input_id[j]) + str(tag[j].item()))

            print(''.join(parts))
        print('==========================')


# Run the prediction demo on one validation batch.
predict()

[CLS]要建立健全信息网络，增强工作的预见性、前瞻性。[SEP]
[CLS]7·······················[SEP]7
[CLS]7·······················[SEP]7
[CLS]该书是迄今为止国内较为全面、系统地研究周恩来经济思想的一部专著，有助于深化对周恩来思想的研究。[SEP]
[CLS]7···················周1恩2来2················周1恩2来2······[SEP]7
[CLS]7···················周1恩2来2················周1恩2来2······[SEP]7
[CLS]会议之后，国际原油市场仍明显供大于求，价格十分疲软。[SEP]
[CLS]7··························[SEP]7
[CLS]7··························[SEP]7
[CLS]两个月后少女平静地离去，她的身边簇拥着俊平的朋友们，枕边还放着俊平为她捎去的书。[SEP]
[CLS]7···················俊1平2··········俊1平2·······[SEP]7
[CLS]7···················俊1平2··········俊1平2·······[SEP]7
[CLS]第二，在改制后的国有大中型企业设置党组。[SEP]
[CLS]7····················[SEP]7
[CLS]7····················[SEP]7
[CLS]中央政府严格遵守《基本法》，对于香港首届立法会选举事务未予任何干预。[SEP]
[CLS]7················香3港4首4届4立4法4会4···········[SEP]7
[CLS]7················香3港4首4届4立4法4会4···········[SEP]7
[CLS]对于发展中国家而言，汽车工业的集中化趋势显然是严峻的挑战。[SEP]
[CLS]7·····························[SEP]7
[CLS]7·····························[SEP]7
[CLS]时隔不久，中央电视台也对武汉电视台敞开了大门，把他们纳入了[UNK]科普宣传国家队[UNK