In [1]:
import torch
from chinesebert import ChineseBertForMaskedLM, ChineseBertTokenizerFast, ChineseBertConfig

pretrained_model_name = "junnyu/ChineseBERT-base"

tokenizer = ChineseBertTokenizerFast.from_pretrained(pretrained_model_name)
chinese_bert = ChineseBertForMaskedLM.from_pretrained(pretrained_model_name)

text = "北京是[MASK]国的首都。"
inputs = tokenizer(text, return_tensors="pt")
print(inputs)
maskpos = 4

with torch.no_grad():
    o = chinese_bert(**inputs)
    value, index = o.logits.softmax(-1)[0, maskpos].topk(10)

pred_tokens = tokenizer.convert_ids_to_tokens(index.tolist())
pred_values = value.tolist()

outputs = []
for t, p in zip(pred_tokens, pred_values):
    outputs.append(f"{t}|{round(p,4)}")
print(outputs)

# base  ['中|0.711', '我|0.2488', '祖|0.016', '法|0.0057', '美|0.0048', '全|0.0042', '韩|0.0015', '英|0.0011', '两|0.0008', '王|0.0006']
# large ['中|0.8341', '我|0.1479', '祖|0.0157', '全|0.0007', '国|0.0005', '帝|0.0001', '该|0.0001', '法|0.0001', '一|0.0001', '咱|0.0001']


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'BertTokenizer'. 
The class this function is called from is 'ChineseBertTokenizerFast'.


{'input_ids': tensor([[ 101, 1266,  776, 3221,  103, 1744, 4638, 7674, 6963,  511,  102]]), 'pinyin_ids': tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  7, 10, 14,  3,  0,  0,  0,  0, 15, 14,
         19, 12,  1,  0,  0,  0, 24, 13, 14,  4,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0, 12, 26, 20,  2,  0,  0,  0,  0,  9, 10,  5,  0,  0,  0,
          0,  0, 24, 13, 20, 26,  3,  0,  0,  0,  9, 26,  1,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
['中|0.711', '我|0.2488', '祖|0.016', '法|0.0057', '美|0.0048', '全|0.0042', '韩|0.0015', '英|0.0011', '两|0.0008', '王|0.0006']


In [5]:
from transformers import DataCollatorWithPadding

from chinesebert import ChineseBertTokenizerFast

tokenizer = ChineseBertTokenizerFast.from_pretrained("junnyu/ChineseBERT-base")
collate_fn = DataCollatorWithPadding(tokenizer)
textlist = ["弗洛伊德的悲剧凸显了在美国和世界范围", "紧迫性和重要性，国际社会必须立", "那些存在严重种族主义、种族歧视", "中方对巴基斯坦开普省发"]
batch_list = [tokenizer(t) for t in textlist]
for i in textlist:
  print(tokenizer(i))
  break
batch = collate_fn(batch_list)
#print(batch.to("cuda:0"))

for i, e in enumerate(batch):

  print(e["pinyin_ids"] == tokenizer(textlist[i])["pinyin_ids"])



The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'BertTokenizer'. 
The class this function is called from is 'ChineseBertTokenizerFast'.


{'input_ids': [101, 2472, 3821, 823, 2548, 4638, 2650, 1196, 1137, 3227, 749, 1762, 5401, 1744, 1469, 686, 4518, 5745, 1741, 102], 'pinyin_ids': [0, 0, 0, 0, 0, 0, 0, 0, 11, 26, 2, 0, 0, 0, 0, 0, 17, 26, 20, 4, 0, 0, 0, 0, 30, 14, 1, 0, 0, 0, 0, 0, 9, 10, 2, 0, 0, 0, 0, 0, 9, 10, 5, 0, 0, 0, 0, 0, 7, 10, 14, 1, 0, 0, 0, 0, 15, 26, 4, 0, 0, 0, 0, 0, 25, 26, 1, 0, 0, 0, 0, 0, 29, 14, 6, 19, 3, 0, 0, 0, 17, 10, 5, 0, 0, 0, 0, 0, 31, 6, 14, 4, 0, 0, 0, 0, 18, 10, 14, 3, 0, 0, 0, 0, 12, 26, 20, 2, 0, 0, 0, 0, 13, 10, 2, 0, 0, 0, 0, 0, 24, 13, 14, 4, 0, 0, 0, 0, 15, 14, 10, 4, 0, 0, 0, 0, 11, 6, 19, 4, 0, 0, 0, 0, 28, 10, 14, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}


TypeError: string indices must be integers

In [None]:
"""[ 0,  0,  0,  0,  0,  0,  0,  0, 11, 26,  2,  0,  0,  0,  0,  0, 17, 26,
         20,  4,  0,  0,  0,  0, 30, 14,  1,  0,  0,  0,  0,  0,  9, 10,  2,  0,
          0,  0,  0,  0,  9, 10,  5,  0,  0,  0,  0,  0,  7, 10, 14,  1,  0,  0,
          0,  0, 15, 26,  4,  0,  0,  0,  0,  0, 25, 26,  1,  0,  0,  0,  0,  0,
         29, 14,  6, 19,  3,  0,  0,  0, 17, 10,  5,  0,  0,  0,  0,  0, 31,  6,
         14,  4,  0,  0,  0,  0, 18, 10, 14,  3,  0,  0,  0,  0, 12, 26, 20,  2,
          0,  0,  0,  0, 13, 10,  2,  0,  0,  0,  0,  0, 24, 13, 14,  4,  0,  0,
          0,  0, 15, 14, 10,  4,  0,  0,  0,  0, 11,  6, 19,  4,  0,  0,  0,  0,
         28, 10, 14,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],"""

In [5]:
import torch
import torch.nn as nn

In [9]:
input1 = torch.randn(100, 128)
input2 = torch.randn(100, 128)
cos = nn.CosineSimilarity(dim=0, eps=1e-6)
output = cos(input1, input2)

In [13]:
a = cos(torch.tensor([0.0]), torch.tensor([0.0]))

In [17]:
torch.mean(torch.tensor([0.0]))

tensor(0.)

In [1]:
from transformers import AutoTokenizer

tokenizer_model_name_path="hfl/chinese-roberta-wwm-ext"

tokenizer = AutoTokenizer.from_pretrained(tokenizer_model_name_path)


In [18]:
tokenizer.decode([103])

'[MASK]'

In [19]:
import pickle

with open("tmp.pkl", "rb") as f:
    score = pickle.load(f)


In [31]:
print(score.shape)

torch.Size([52, 21128])


In [37]:
topk3 = score.topk(3, dim=1)[-1]

In [38]:
topk3

tensor([[2769,  511,  872],
        [ 872, 2644, 2769],
        [1962,  812, 2769],
        [ 106,  131,  117],
        [2769, 6443, 2218],
        [3221, 1373, 4263],
        [2476, 2769, 2484],
        [4263,  136, 2695],
        [3152, 2769,  872],
        [ 511, 8013,  117],
        [2769,  511,  872],
        [ 511, 8013,  106],
        [ 511,  106,  136],
        [ 511,  106, 8013],
        [ 511,  106, 8013],
        [ 511, 8013,  106],
        [ 511,  106, 8013],
        [ 511, 8013,  106],
        [ 511, 8013,  106],
        [ 511, 8013,  106],
        [ 106, 2769,  511],
        [ 511, 8013,  106],
        [ 511, 8013,  106],
        [ 511, 8013,  106],
        [ 511, 8013,  106],
        [ 511, 8013,  106],
        [2769,  872, 6468],
        [ 511, 8013,  106],
        [2769,  872, 6468],
        [ 872, 2769, 6468],
        [2769,  872, 6468],
        [ 511, 8013,  106],
        [2769,  872, 6468],
        [ 136,  511,  106],
        [ 511,  106, 8013],
        [ 511, 8013,