In [1]:
import os
import platform

import torch
from transformers import BertModel, BertTokenizer

In [2]:
import transformers

# Record the library version so readers know which release produced the
# outputs in this notebook.
TRANSFORMERS_VERSION = transformers.__version__
print(TRANSFORMERS_VERSION)

3.1.0


In [3]:
# Local directory (relative to the notebook's working directory) holding the
# pretrained bert-base-chinese checkpoint; its contents are listed in the next cell.
basedir = "data/bert-base-chinese"

In [4]:
# os.listdir returns entries in arbitrary, filesystem-dependent order; sort the
# listing so this cell's output is reproducible across machines.
model_files = sorted(os.listdir(basedir))
# Fail early with a clear message if the checkpoint directory is incomplete,
# instead of a less obvious error from from_pretrained further down.
missing_files = {"config.json", "pytorch_model.bin", "vocab.txt"} - set(model_files)
assert not missing_files, f"checkpoint directory {basedir!r} is missing {sorted(missing_files)}"
model_files

['config.json', 'pytorch_model.bin', 'vocab.txt']

In [5]:
# Load the BERT tokenizer (vocab.txt) from the local checkpoint directory.
tokenizer = BertTokenizer.from_pretrained(basedir)
# `tokernizer` is a misspelling kept as an alias because later cells still use it.
tokernizer = tokenizer

In [6]:
# Load the pretrained encoder (config.json + pytorch_model.bin) from the
# same local checkpoint directory as the tokenizer.
model = BertModel.from_pretrained(pretrained_model_name_or_path=basedir)

In [7]:
# Encode one sentence into PyTorch tensors (input_ids, token_type_ids,
# attention_mask). To inspect the raw sub-word tokens instead, call
# `tokernizer.tokenize(text)` on the same string.
inputs = tokernizer("我们来试试牛逼的bert模型吧", return_tensors="pt")

In [8]:
# Show the encoded batch (dict-like): each value is a (1, seq_len) tensor.
inputs

{'input_ids': tensor([[ 101, 2769,  812, 3341, 6407, 6407, 4281, 6873, 4638, 8815, 8716, 3563,
         1798, 1416,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [9]:
# Map the ids back to text to reveal the [CLS]/[SEP] markers the tokenizer added.
# decode() accepts a plain list of ints, so the .data/.cpu()/.numpy() round-trip
# is unnecessary (and .data bypasses autograd bookkeeping).
tokernizer.decode(inputs["input_ids"].flatten().tolist())

'[CLS] 我 们 来 试 试 牛 逼 的 bert 模 型 吧 [SEP]'

In [10]:
# Inference only: disable autograd so no computation graph is built and kept
# alive alongside the outputs.
with torch.no_grad():
    outputs = model(**inputs)

In [11]:
# Per-token hidden states: (batch, seq_len, hidden).
outputs[0].size()

torch.Size([1, 15, 768])

In [12]:
# Pooled sentence representation: (batch, hidden).
# Index it explicitly as element 1: in transformers 3.x, outputs[-1] is the
# pooled vector only when the model does not also return hidden states or
# attentions.
outputs[1].shape

torch.Size([1, 768])

## 输入两个句子

将两个句子作为句子对输入时，分词器会在两句之间插入 `[SEP]`，并用 `token_type_ids` 区分第一句（0）和第二句（1）。

In [13]:
# Encode a sentence pair: the first string is segment A, the second segment B.
inputs2 = tokernizer(
    text="我们来试试牛逼的BERT模型吧",
    text_pair="听说BERT模型吊炸天！",
    return_tensors="pt",
)

In [14]:
# Show the pair encoding; token_type_ids switches from 0 to 1 at the second sentence.
inputs2

{'input_ids': tensor([[ 101, 2769,  812, 3341, 6407, 6407, 4281, 6873, 4638, 8815, 8716, 3563,
         1798, 1416,  102, 1420, 6432, 8815, 8716, 3563, 1798, 1396, 4156, 1921,
         8013,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1]])}

In [15]:
# Decode the pair to show where [SEP] separates the two sentences.
# decode() accepts a plain list of ints, so the .data/.cpu()/.numpy() round-trip
# is unnecessary.
tokernizer.decode(inputs2["input_ids"].flatten().tolist())

'[CLS] 我 们 来 试 试 牛 逼 的 bert 模 型 吧 [SEP] 听 说 bert 模 型 吊 炸 天 ！ [SEP]'

In [16]:
# Run the sentence pair without tracking gradients (inference only).
with torch.no_grad():
    outputs2 = model(**inputs2)
# Index rather than tuple-unpack. Unpacking fails if the model also returns
# hidden states or attentions. On newer transformers releases the returned
# ModelOutput unpacks to its key names instead of the tensors.
seq_outputs, pooled_outputs = outputs2[0], outputs2[1]

In [17]:
# (batch, seq_len, hidden) for the sentence pair.
seq_outputs.size()

torch.Size([1, 26, 768])

In [18]:
# (batch, hidden): one pooled vector for the whole pair.
pooled_outputs.size()

torch.Size([1, 768])

In [19]:
# Per-token hidden states from the encoder (torch's repr truncates the middle rows/columns).
seq_outputs

tensor([[[ 0.4671, -0.2775, -0.8004,  ..., -0.3845, -0.7908, -0.0454],
         [ 0.2140, -0.4219, -0.1619,  ..., -0.5265, -0.9865,  0.0982],
         [ 0.4882, -1.4815, -0.8973,  ...,  0.8416,  0.2024,  0.2813],
         ...,
         [ 0.6722, -0.2272, -0.4810,  ..., -0.1517,  0.6246, -0.2592],
         [-0.7090, -0.2184,  0.4602,  ...,  0.0310,  0.0080, -0.7972],
         [-0.0265, -0.4677, -0.8057,  ..., -0.2529, -0.6123, -0.5260]]],
       grad_fn=<AddcmulBackward>)

In [20]:
# Average the token vectors into one sentence vector (mean pooling).
# Weight by attention_mask so padding positions would be excluded in a padded
# batch. For this unpadded input the mask is all ones, so the result equals
# seq_outputs.mean(1). Note that [CLS] and [SEP] are included in the average.
pooling_mask = inputs2["attention_mask"].unsqueeze(-1).to(seq_outputs.dtype)
(seq_outputs * pooling_mask).sum(dim=1) / pooling_mask.sum(dim=1).clamp(min=1)

tensor([[ 4.2929e-02, -3.0075e-01, -5.7258e-01,  5.3962e-01,  5.7433e-01,
         -5.8809e-01,  3.2765e-01, -1.3403e-01, -4.5887e-01,  3.9097e-01,
         -1.1773e-01, -2.9848e-01, -1.5606e-02,  3.5034e-01, -5.9695e-02,
         -1.5067e-01, -4.8329e-02,  1.5288e-01,  2.7167e-02,  2.4180e-01,
         -2.2730e-01,  3.1727e-01, -5.5646e-02,  7.9556e-02, -2.0635e-01,
         -3.5557e-01, -1.6777e-01, -3.2059e-01,  4.0694e-01, -6.0094e-01,
         -4.8687e-01, -4.6362e-01, -1.1115e-01,  2.7846e-01,  2.8935e-01,
         -4.6631e-01, -4.4169e-01, -9.3158e-02, -4.9807e-01, -3.9810e-01,
         -2.6256e-01, -1.6920e-01, -4.5223e-01,  7.4979e-02, -4.0159e-02,
          1.2917e-01,  3.0358e-01,  4.3336e-01,  8.9806e-02,  3.5380e-01,
          1.2196e-01,  8.4199e+00, -6.2707e-02, -4.2542e-01, -6.4139e-01,
          6.8191e-01,  1.0297e+00, -2.2472e-01, -7.0235e-02, -4.9721e-01,
         -1.0397e-01,  2.4219e-01, -9.3619e-02, -3.3611e-01, -7.0822e-03,
         -1.5673e-01, -7.2997e-02,  4.

In [21]:
# Hidden dimension 0 for every token of the first example. The slice covers
# the whole sequence instead of hard-coding its length (26).
seq_outputs[0, :, 0]

tensor([ 0.4671,  0.2140,  0.4882, -0.2431, -0.5351, -0.7403,  0.1061,  0.8377,
        -0.6703,  0.5808,  1.6504, -0.1127, -0.5356,  0.2803, -0.0265,  0.3146,
        -0.6757,  0.2130,  1.3798, -0.2864, -0.7646,  0.0787, -0.8409,  0.6722,
        -0.7090, -0.0265], grad_fn=<SelectBackward>)

In [22]:
# Sanity check: the mean over all tokens of dimension 0 should equal element 0
# of the sentence-level average computed earlier.
seq_outputs[0, :, 0].mean()

tensor(0.0429, grad_fn=<MeanBackward0>)

In [23]:
# Same check for dimension 1 (element 1 of the sentence-level average).
seq_outputs[0, :, 1].mean()

tensor(-0.3008, grad_fn=<MeanBackward0>)

In [24]:
# List every trainable parameter tensor (numbered from 1) with its shape.
for position, param in enumerate(model.parameters(), start=1):
    print(position, param.shape)

1 torch.Size([21128, 768])
2 torch.Size([512, 768])
3 torch.Size([2, 768])
4 torch.Size([768])
5 torch.Size([768])
6 torch.Size([768, 768])
7 torch.Size([768])
8 torch.Size([768, 768])
9 torch.Size([768])
10 torch.Size([768, 768])
11 torch.Size([768])
12 torch.Size([768, 768])
13 torch.Size([768])
14 torch.Size([768])
15 torch.Size([768])
16 torch.Size([3072, 768])
17 torch.Size([3072])
18 torch.Size([768, 3072])
19 torch.Size([768])
20 torch.Size([768])
21 torch.Size([768])
22 torch.Size([768, 768])
23 torch.Size([768])
24 torch.Size([768, 768])
25 torch.Size([768])
26 torch.Size([768, 768])
27 torch.Size([768])
28 torch.Size([768, 768])
29 torch.Size([768])
30 torch.Size([768])
31 torch.Size([768])
32 torch.Size([3072, 768])
33 torch.Size([3072])
34 torch.Size([768, 3072])
35 torch.Size([768])
36 torch.Size([768])
37 torch.Size([768])
38 torch.Size([768, 768])
39 torch.Size([768])
40 torch.Size([768, 768])
41 torch.Size([768])
42 torch.Size([768, 768])
43 torch.Size([768])
44 torch.S

In [25]:
# Print every state-dict key, one per line. Unlike parameters(), the state
# dict also contains buffers such as embeddings.position_ids.
print(*model.state_dict(), sep="\n")

embeddings.position_ids
embeddings.word_embeddings.weight
embeddings.position_embeddings.weight
embeddings.token_type_embeddings.weight
embeddings.LayerNorm.weight
embeddings.LayerNorm.bias
encoder.layer.0.attention.self.query.weight
encoder.layer.0.attention.self.query.bias
encoder.layer.0.attention.self.key.weight
encoder.layer.0.attention.self.key.bias
encoder.layer.0.attention.self.value.weight
encoder.layer.0.attention.self.value.bias
encoder.layer.0.attention.output.dense.weight
encoder.layer.0.attention.output.dense.bias
encoder.layer.0.attention.output.LayerNorm.weight
encoder.layer.0.attention.output.LayerNorm.bias
encoder.layer.0.intermediate.dense.weight
encoder.layer.0.intermediate.dense.bias
encoder.layer.0.output.dense.weight
encoder.layer.0.output.dense.bias
encoder.layer.0.output.LayerNorm.weight
encoder.layer.0.output.LayerNorm.bias
encoder.layer.1.attention.self.query.weight
encoder.layer.1.attention.self.query.bias
encoder.layer.1.attention.self.key.weight
encoder.la

In [27]:
# Environment setup, kept commented out: pins the transformers release that
# produced the outputs in this notebook (the version printed in cell 2).
# If re-enabling, prefer `%pip` over `!pip` so the install targets the kernel's
# environment.
#!pip install transformers==3.1.0 -i https://mirrors.aliyun.com/pypi/simple

In [28]:
# Record the kernel's Python version next to the transformers version. This is a
# pure-Python replacement for `!python -V`, which reports whatever `python` is
# first on the shell PATH and may not be the interpreter running this kernel.
PYTHON_VERSION = platform.python_version()
PYTHON_VERSION