In [1]:
#第14章/读取字典
import pandas as pd

# The vocabulary CSV is parsed twice, once per lookup direction:
#   vocab   is indexed by word  (word -> token id)
#   vocab_r is indexed by token (token id -> word)
VOCAB_CSV = 'data/msr_paraphrase_vocab.csv'
vocab, vocab_r = (pd.read_csv(VOCAB_CSV, index_col=key)
                  for key in ('word', 'token'))

# Display both tables.
vocab, vocab_r

(            token
 word             
 <PAD>           0
 <SOS>           1
 <EOS>           2
 <NUM>           3
 <UNK>           4
 ...           ...
 eastbound   14784
 clouds      14785
 repave      14786
 complained  14787
 dominate    14788
 
 [14789 rows x 1 columns],
              word
 token            
 0           <PAD>
 1           <SOS>
 2           <EOS>
 3           <NUM>
 4           <UNK>
 ...           ...
 14784   eastbound
 14785      clouds
 14786      repave
 14787  complained
 14788    dominate
 
 [14789 rows x 1 columns])

In [2]:
#第14章/定义数据集
import torch


class MsrDataset(torch.utils.data.Dataset):
    """Map-style dataset whose samples are the rows of the paraphrase CSV."""

    def __init__(self):
        # The whole CSV is loaded eagerly; every row is one sample.
        self.data = pd.read_csv('data/msr_paraphrase_data.csv')

    def __len__(self):
        # Number of samples == number of rows in the table.
        return self.data.shape[0]

    def __getitem__(self, i):
        # Positional row lookup; the row is returned as a pandas Series.
        row = self.data.iloc[i]
        return row


# Instantiate the dataset (the CSV is read inside __init__).
dataset = MsrDataset()

# Display the number of samples and the first row.
len(dataset), dataset[0]

(5801,
 same                                                        1
 s1_lens                                                    16
 s2_lens                                                    17
 pad_lens                                                   39
 sent        1,11,12,13,14,15,16,17,18,19,20,21,22,13,23,2,...
 Name: 0, dtype: object)

In [3]:
#第14章/定义数据整理函数
import numpy as np


def collate_fn(data):
    """Turn a list of dataset rows into three batched int64 tensors.

    Each row must support ``row[key]`` for the keys ``'same'``, ``'sent'``,
    ``'s1_lens'``, ``'s2_lens'`` and ``'pad_lens'``. ``'sent'`` is a string of
    comma-separated token ids.

    Returns:
        same: [b]    labels taken from the ``'same'`` field.
        sent: [b, L] token ids parsed from the ``'sent'`` field.
        seg:  [b, L] segment ids: 1 for sentence-1 positions, 2 for
              sentence-2 positions, 0 for padding positions.

    All rows of a batch must describe sequences of the same length, otherwise
    the tensor construction raises.
    """
    same = [row['same'] for row in data]

    # seg mirrors the layout of sent but carries segment ids instead of tokens.
    seg = [[1] * int(row['s1_lens']) + [2] * int(row['s2_lens']) +
           [0] * int(row['pad_lens']) for row in data]

    # Parse the comma-separated ids straight to Python ints. The former
    # np.array(..., dtype=np.int) fails on NumPy >= 1.24, where the np.int
    # alias no longer exists.
    sent = [[int(tok) for tok in row['sent'].split(',')] for row in data]

    same = torch.LongTensor(same)
    sent = torch.LongTensor(sent)
    seg = torch.LongTensor(seg)

    return same, sent, seg


# Smoke-test the collate function on the first two samples.
collate_fn([dataset[0], dataset[1]])



(tensor([1, 0]),
 tensor([[ 1, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 13, 23,  2, 24, 25,
          26, 27, 28, 18, 19, 11, 12, 13, 14, 20, 21, 22, 13, 23,  2,  0,  0,  0,
           0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 1, 29, 30, 31, 32, 33, 34, 18, 35, 25, 36, 37,  3, 38,  3,  3, 39,  2,
          29, 40, 31, 32, 37,  3, 38,  3, 41, 42, 43, 44, 25, 36, 38,  3,  3, 39,
          37,  3,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]]),
 tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
          2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 

In [4]:
#第14章/定义数据加载器
# Batches of 32 samples with shuffling enabled; the final incomplete batch is
# dropped, and collate_fn converts each list of rows into tensors.
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=32,
                                     shuffle=True,
                                     drop_last=True,
                                     collate_fn=collate_fn)

# Number of batches per pass over the loader.
len(loader)

181

In [5]:
#第14章/查看数据样例
# Pull a single batch from the loader so later cells have sample tensors.
for i, (same, sent, seg) in enumerate(loader):
    break

# Display the labels, the tensor shapes, and the first sample's tokens/segments.
same, sent.shape, seg.shape, sent[0], seg[0]

(tensor([0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0,
         0, 1, 0, 1, 0, 1, 1, 1]),
 torch.Size([32, 72]),
 torch.Size([32, 72]),
 tensor([   1,  555, 1538, 1159,  720, 1405,  359, 6912,   18, 5386,   38, 1757,
           42, 3992, 2125,    2, 1405, 6913,   65,  153, 1757,   42, 3992, 2125,
          154, 6722,    2,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0]),
 tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
         2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))

In [6]:
#第14章/定义随机替换函数
import random


def random_replace(sent):
    """Randomly corrupt word tokens for the masked-word objective.

    Token ids <= 10 are treated as symbols and never touched. Every other
    position is selected with probability 0.15; a selected position is then
    - replaced by the ``<MASK>`` token id from ``vocab``  (probability 0.8),
    - left unchanged                                      (probability 0.1),
    - replaced by a random non-symbol token id            (probability 0.1).

    Args:
        sent: [b, L] integer tensor of token ids. It is not modified.

    Returns:
        (corrupted, replace): the corrupted copy of ``sent`` and a boolean
        tensor of the same shape that is True at every selected position,
        including the selected-but-unchanged ones.
    """
    # Work on a copy so the caller's tensor stays intact.
    corrupted = sent.clone()

    # Selection mask; starts all False as long as no token id equals -1.
    replace = corrupted.eq(-1)

    for row in range(len(corrupted)):
        for col in range(len(corrupted[row])):
            # Skip symbols, then skip the ~85% of words that are not selected.
            if corrupted[row, col] <= 10 or random.random() > 0.15:
                continue

            # Selected, whatever happens to the token next.
            replace[row, col] = True

            draw = random.random()
            if draw >= 0.9:
                # Random word: redraw until the id is not a symbol.
                candidate = 0
                while candidate <= 10:
                    candidate = random.randint(0, len(vocab) - 1)
                corrupted[row, col] = candidate
            elif draw < 0.8:
                corrupted[row, col] = vocab.loc['<MASK>'].token
            # 0.8 <= draw < 0.9: the token is deliberately kept as is.

    return corrupted, replace


# Apply the random replacement to the sample batch.
replace_sent, replace = random_replace(sent)

# Show the tokens now sitting at the selected positions.
replace_sent[replace]

tensor([    5,     5,     5,  8342,     5,  7568,  2984,     5,     5,     5,
            5,     5,     5,     5,     5,     5,  7238,     5, 10831,     5,
         5635,     5,     5,     5,    25,  6697,     5,     5,     5,     5,
            5,     5,     5,     5,     5,     5,     5,     5,     5,     5,
            5,     5, 13424,     5,  3787,     5,     5,     5,  4711,     5,
          448,   704,     5,  4711,  5834, 12720,     5,     5,     5,     5,
           32,     5,     5,     5,     5,  1364,     5,   163,   238,     5,
            5,     5,     5,   394,     5,    25,     5,     5,     5,     5,
            5,     5,     5,     5, 13377,     5,     5,     5,     5,    18,
            5,     5,     5,     5,     5,    69,     5,     5,     5,   145,
            5,     5, 10282,     5,     5,     5,     5,     5,     5,     5,
            5,     5,     5,     5,     5,     5,     5,     5,     5,     5,
            5,     5,     5,  1740,     5,     5,     5,     5, 

In [7]:
#第14章/定义获取mask函数
def get_mask(seg):
    """Build the two masks passed to the transformer encoder.

    Args:
        seg: [b, L] segment ids, where 0 marks a padding position.

    Returns:
        key_padding_mask: [b, L] bool tensor, True where ``seg`` is 0 (PAD).
        encode_attn_mask: [L, L] bool tensor that is all False, i.e. no
            position is blocked. It is created on the same device as ``seg``.
    """
    # True at PAD positions, False elsewhere.
    key_padding_mask = seg == 0

    # No attention restriction is needed while encoding, so the mask is all
    # False. Its size follows the actual sequence length instead of a fixed 72.
    seq_len = seg.shape[-1]
    encode_attn_mask = torch.zeros(seq_len,
                                   seq_len,
                                   dtype=torch.bool,
                                   device=seg.device)

    return key_padding_mask, encode_attn_mask


# Build both masks from the sample batch's segment ids.
key_padding_mask, encode_attn_mask = get_mask(seg)

# Display the mask shapes, the first sample's padding mask, and the attention mask.
key_padding_mask.shape, encode_attn_mask.shape, key_padding_mask[
    0], encode_attn_mask

(torch.Size([32, 72]),
 torch.Size([72, 72]),
 tensor([False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True]),
 tensor([[False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False]]))

In [8]:
#第14章/定义模型
class BERTModel(torch.nn.Module):
    """Small transformer encoder with two output heads.

    ``fc_same`` produces 2 scores per sample from the hidden state at
    position 0; ``fc_sent`` produces one score per vocabulary entry for every
    position. The vocabulary size is read from the module-level ``vocab``.
    """

    def __init__(self):
        super().__init__()

        hidden = 256
        vocab_size = len(vocab)

        # Token embedding.
        self.sent_embed = torch.nn.Embedding(num_embeddings=vocab_size,
                                             embedding_dim=hidden)

        # Segment embedding: 3 ids (0 = PAD, 1 = sentence 1, 2 = sentence 2).
        self.seg_embed = torch.nn.Embedding(num_embeddings=3,
                                            embedding_dim=hidden)

        # Learned position embedding for 72 positions, initialised small.
        self.position_embed = torch.nn.Parameter(
            torch.randn(72, hidden) / 10)

        # One encoder layer; TransformerEncoder stacks 3 of them below.
        layer = torch.nn.TransformerEncoderLayer(d_model=hidden,
                                                 nhead=4,
                                                 dim_feedforward=256,
                                                 dropout=0.2,
                                                 activation='relu',
                                                 batch_first=True,
                                                 norm_first=True)

        # Layer norm handed to the encoder stack via its `norm` argument.
        final_norm = torch.nn.LayerNorm(normalized_shape=hidden,
                                        elementwise_affine=True)

        self.encoder = torch.nn.TransformerEncoder(encoder_layer=layer,
                                                   num_layers=3,
                                                   norm=final_norm)

        # Head for the `same` output.
        self.fc_same = torch.nn.Linear(in_features=hidden, out_features=2)

        # Head for the per-position vocabulary output.
        self.fc_sent = torch.nn.Linear(in_features=hidden,
                                       out_features=vocab_size)

    def forward(self, sent, seg):
        """Encode a batch and return ``(same, sent)`` scores.

        Args:
            sent: [b, 72] token ids.
            seg:  [b, 72] segment ids (0 marks PAD).

        Returns:
            same: [b, 2] scores computed from position 0.
            sent: [b, 72, V] per-position vocabulary scores.
        """
        # [b, 72] -> [b, 72], [72, 72]
        pad_mask, attn_mask = get_mask(seg)

        # Sum of token, segment and position embeddings.
        # [b, 72] -> [b, 72, 256]
        token_part = self.sent_embed(sent)
        segment_part = self.seg_embed(seg)
        embed = token_part + segment_part + self.position_embed

        # [b, 72, 256] -> [b, 72, 256]
        memory = self.encoder(src=embed,
                              mask=attn_mask,
                              src_key_padding_mask=pad_mask)

        # The `same` head only looks at the hidden state of position 0.
        # [b, 256] -> [b, 2]
        first_state = memory[:, 0]
        same = self.fc_same(first_state)

        # [b, 72, 256] -> [b, 72, V]
        sent = self.fc_sent(memory)

        return same, sent


# Instantiate the model with freshly initialised weights.
model = BERTModel()

# Run the sample batch through the model once as a shape check.
pred_same, pred_sent = model(sent, seg)

# Display the shapes of the two outputs.
pred_same.shape, pred_sent.shape

(torch.Size([32, 2]), torch.Size([32, 72, 14789]))

In [9]:
#第14章/训练
def train():
    """Train the module-level ``model`` on ``loader`` for 2000 epochs.

    Every batch is corrupted with ``random_replace``. The loss is
    ``0.1 * CE(same) + CE(selected words)``, where the second term only covers
    the positions that ``random_replace`` selected. Every 5th epoch the loss
    and both accuracies of the last batch of that epoch are printed.

    Relies on the module-level names ``model``, ``loader`` and
    ``random_replace``; returns nothing.
    """
    # Make sure dropout is active even if the model was previously switched to
    # eval mode (e.g. by test()).
    model.train()

    loss_func = torch.nn.CrossEntropyLoss()
    optim = torch.optim.Adam(model.parameters(), lr=1e-4)
    for epoch in range(2000):
        for i, (same, sent, seg) in enumerate(loader):
            #same = [b]
            #sent = [b, 72]
            #seg = [b, 72]

            # Randomly corrupt some words. `replace` marks every selected
            # position, including the ones that were left unchanged.
            #replace_sent = [b, 72]
            #replace = [b, 72]
            replace_sent, replace = random_replace(sent)

            #[b, 72],[b, 72] -> [b, 2],[b, 72, V]
            pred_same, pred_sent = model(replace_sent, seg)

            # Keep only the predictions at the selected positions.
            #[b, 72, V] -> [replace, V]
            pred_sent = pred_sent[replace]

            # Original tokens at the selected positions are the targets.
            #[b, 72] -> [replace]
            sent = sent[replace]

            # Weighted sum of the two losses.
            loss_same = loss_func(pred_same, same)
            loss = loss_same * 0.1
            if len(sent) > 0:
                # With no selected position the mean cross-entropy would be
                # NaN and poison the gradients, so the term is skipped.
                loss = loss + loss_func(pred_sent, sent)

            loss.backward()
            optim.step()
            optim.zero_grad()

        if epoch % 5 == 0:
            # Accuracy of the `same` prediction on the last batch.
            pred_same = pred_same.argmax(dim=1)
            acc_same = (same == pred_same).sum().item() / len(same)

            # Accuracy on the selected words of the last batch; max(..., 1)
            # avoids dividing by zero when nothing was selected.
            pred_sent = pred_sent.argmax(dim=1)
            acc_sent = (sent == pred_sent).sum().item() / max(len(sent), 1)

            print(epoch, i, loss.item(), acc_same, acc_sent)


# Run the training loop (2000 epochs; progress is printed every 5th epoch).
train()

0 180 7.583975791931152 0.71875 0.05202312138728324
5 180 7.01627254486084 0.71875 0.05357142857142857
10 180 6.79768705368042 0.8125 0.08021390374331551
15 180 6.868100166320801 0.71875 0.09947643979057591
20 180 6.631269454956055 0.6875 0.09375
25 180 6.179685115814209 0.6875 0.13043478260869565
30 180 6.160305023193359 0.6875 0.10650887573964497
35 180 6.393719673156738 0.6875 0.1187214611872146
40 180 6.344337463378906 0.8125 0.11518324607329843
45 180 5.619174957275391 0.65625 0.16304347826086957
50 180 6.080831527709961 0.625 0.08791208791208792
55 180 5.971240997314453 0.65625 0.10429447852760736
60 180 5.649734973907471 0.78125 0.14285714285714285
65 180 5.619324207305908 0.71875 0.14792899408284024
70 180 5.730544090270996 0.78125 0.13131313131313133
75 180 5.373783111572266 0.75 0.12206572769953052
80 180 4.9394941329956055 0.84375 0.1686046511627907
85 180 5.006924629211426 0.6875 0.1657754010695187
90 180 4.841911792755127 0.78125 0.15463917525773196
95 180 5.02555370330810

800 180 1.6947892904281616 1.0 0.5945945945945946
805 180 2.1447057723999023 1.0 0.5245098039215687
810 180 2.0740888118743896 1.0 0.5444444444444444
815 180 1.9558393955230713 0.96875 0.5277777777777778
820 180 2.089111804962158 0.96875 0.5536723163841808
825 180 1.8715225458145142 0.96875 0.6043956043956044
830 180 1.7022827863693237 0.9375 0.5632183908045977
835 180 2.0040524005889893 0.9375 0.5583756345177665
840 180 2.0123708248138428 0.96875 0.5459770114942529
845 180 1.785885214805603 0.96875 0.5661375661375662
850 180 1.9366830587387085 0.9375 0.5577889447236181
855 180 2.1804163455963135 0.96875 0.5195530726256983
860 180 1.804997444152832 0.9375 0.5751295336787565
865 180 1.9716240167617798 0.96875 0.5463917525773195
870 180 1.7420647144317627 0.96875 0.6388888888888888
875 180 2.02553391456604 0.90625 0.5153061224489796
880 180 1.9450682401657104 0.9375 0.6162790697674418
885 180 1.7774468660354614 0.96875 0.5860215053763441
890 180 1.6604479551315308 1.0 0.625
895 180 1.725

1595 180 1.5449641942977905 0.96875 0.6453488372093024
1600 180 1.120186448097229 1.0 0.7103825136612022
1605 180 0.812895655632019 0.96875 0.8
1610 180 1.009718418121338 0.96875 0.7417582417582418
1615 180 1.2270549535751343 0.96875 0.6779661016949152
1620 180 1.0853394269943237 0.96875 0.7252747252747253
1625 180 0.9983455538749695 1.0 0.7543859649122807
1630 180 0.8749478459358215 0.96875 0.7514792899408284
1635 180 0.9588981866836548 0.9375 0.7441860465116279
1640 180 1.3310329914093018 0.96875 0.6666666666666666
1645 180 0.860778272151947 1.0 0.7633136094674556
1650 180 1.118299126625061 1.0 0.6818181818181818
1655 180 0.8316289782524109 0.96875 0.7784090909090909
1660 180 0.8396706581115723 0.96875 0.8128654970760234
1665 180 1.3668744564056396 0.96875 0.6793478260869565
1670 180 0.9910237789154053 0.9375 0.7430555555555556
1675 180 1.167981743812561 0.96875 0.7344632768361582
1680 180 0.8949767351150513 0.96875 0.7474226804123711
1685 180 1.0360218286514282 1.0 0.719806763285024

In [10]:
#第14章/定义工具函数，tensor转换为字符串
def tensor_to_str(tensor):
    """Decode a 1-D tensor of token ids into a space-separated string.

    Ids equal to the ``<PAD>`` token id are dropped; every other id is mapped
    to its word through the module-level ``vocab_r`` table.

    Args:
        tensor: 1-D tensor of token ids.

    Returns:
        The decoded words joined by single spaces.
    """
    # Look the PAD id up once instead of once per token.
    pad_token = vocab.loc['<PAD>'].token

    # Tensor -> plain ints, drop PAD, map ids to words.
    words = [
        vocab_r.loc[token].word for token in tensor.tolist()
        if token != pad_token
    ]

    return ' '.join(words)


# Decode the first sentence of the sample batch back into text.
tensor_to_str(sent[0])

'<SOS> among three major candidates schwarzenegger is wining the battle for independents and crossover voters <EOS> schwarzenegger picks up more independents and crossover voters than bustamante <EOS>'

In [11]:
#第14章/定义工具函数，打印预测结果
def print_predict(same, pred_same, replace_sent, sent, pred_sent, replace):
    """Print label and word predictions for the first sample of a batch.

    Args:
        same:         [b] true labels.
        pred_same:    [b, 2] label scores; argmax over dim 1 is the prediction.
        replace_sent: [b, L] corrupted token ids.
        sent:         [b, L] original token ids.
        pred_sent:    [b, L, C] per-position scores; argmax over dim 2 is the
                      predicted token id.
        replace:      [b, L] bool mask of the selected positions.
    """
    # Label of the first sample versus its prediction.
    true_label = same[0].item()
    guessed_label = pred_same.argmax(dim=1)[0].item()
    print('same=', true_label, 'pred_same=', guessed_label)
    print()

    # Decode the corrupted sentence, then the original and the predicted
    # words at the selected positions of the first sample.
    corrupted_text = tensor_to_str(replace_sent[0])
    first_mask = replace[0]
    target_text = tensor_to_str(sent[0][first_mask])
    predicted_text = tensor_to_str(pred_sent.argmax(dim=2)[0][first_mask])

    report = (
        ('replace_sent=', corrupted_text),
        ('sent=', target_text),
        ('pred_sent=', predicted_text),
    )
    for tag, text in report:
        print(tag, text)
        print()
    print('-------------------------------------')


# Dry run of the printer: random scores stand in for real model outputs
# (the fake word scores only cover 100 classes).
print_predict(same, torch.randn(32, 2), replace_sent, sent,
              torch.randn(32, 72, 100), replace)

same= 0 pred_same= 1

replace_sent= <SOS> among three major candidates schwarzenegger is wining the battle for independents and crossover <MASK> <EOS> schwarzenegger picks up more independents <MASK> crossover <MASK> than bustamante <EOS>

sent= voters and voters

pred_sent= before hanging distorting

-------------------------------------


In [12]:
#第14章/测试
def test():
    """Evaluate the module-level ``model`` on the first 5 batches of ``loader``.

    Each batch is corrupted with ``random_replace``, predicted without
    gradient tracking and printed via ``print_predict``. At the end, the
    accuracy of the ``same`` prediction and the accuracy on the selected words
    are printed. The model is switched to eval mode and left in that mode.
    """
    model.eval()

    same_hits = 0
    same_seen = 0
    word_hits = 0
    word_seen = 0

    for i, (same, sent, seg) in enumerate(loader):
        # Only the first 5 batches are evaluated.
        if i == 5:
            break
        #same = [b]
        #sent = [b, 72]
        #seg = [b, 72]

        # Randomly corrupt some words. `replace` marks every selected
        # position, including the ones that were left unchanged.
        #replace_sent = [b, 72]
        #replace = [b, 72]
        replace_sent, replace = random_replace(sent)

        #[b, 72],[b, 72] -> [b, 2],[b, 72, V]
        with torch.no_grad():
            pred_same, pred_sent = model(replace_sent, seg)

        # Show the first sample of the batch.
        print_predict(same, pred_same, replace_sent, sent, pred_sent, replace)

        # Scores and original tokens at the selected positions only.
        #[b, 72, V] -> [replace, V]
        selected_scores = pred_sent[replace]
        #[b, 72] -> [replace]
        selected_targets = sent[replace]

        # Tally the `same` predictions.
        same_guess = pred_same.argmax(dim=1)
        same_hits += (same == same_guess).sum().item()
        same_seen += len(same)

        # Tally the selected-word predictions.
        word_guess = selected_scores.argmax(dim=1)
        word_hits += (selected_targets == word_guess).sum().item()
        word_seen += len(selected_targets)

    print(same_hits / same_seen)
    print(word_hits / word_seen)


# Run the evaluation on 5 batches and print the two accuracies.
test()

same= 1 pred_same= 1

replace_sent= <SOS> this individual s lawyers are trying <MASK> obtain from the court a free pass to download or upload music online illegally <EOS> her lawyers are trying to obtain a <MASK> pass <MASK> <MASK> <MASK> upload <MASK> <MASK> line illegally <EOS>

sent= to free to download or upload music on

pred_sent= to free or download or upload music or

-------------------------------------
same= 1 pred_same= 1

replace_sent= <SOS> a federal <MASK> court yesterday reinstated <MASK> charges against a san diego student accused of lying about his association <MASK> <NUM> <NUM> hijackers <EOS> a u s appeals court in <MASK> york <MASK> perjury charges against a grossmont college student accused of lying about his knowledge of two of the sept <NUM> hijackers <EOS>

sent= appeals perjury his with new reinstated knowledge hijackers

pred_sent= appeals perjury his with new reinstated knowledge hijackers

-------------------------------------
same= 0 pred_same= 0

replace_