In [2]:
import torch
from datasets import load_dataset
from datasets import load_from_disk
import random


#定义数据集
class Dataset(torch.utils.data.Dataset):
    def __init__(self, split):
#         dataset = load_dataset(path='seamew/ChnSentiCorp', split=split)

#         def f(data):
#             return len(data['text']) > 40

#         self.dataset = dataset.filter(f)
        dataset = load_from_disk('/Users/zard/Documents/nlp002/Huggingface_Toturials/data/ChnSentiCorp')
        self.dataset = dataset[split]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        text = self.dataset[i]['text']

        #切分一句话为前半句和后半句
        sentence1 = text[:20]
        sentence2 = text[20:40]
        label = 0

        #有一半的概率把后半句替换为一句无关的话
        if random.randint(0, 1) == 0:
            j = random.randint(0, len(self.dataset) - 1)
            sentence2 = self.dataset[j]['text'][20:40]
            label = 1

        return sentence1, sentence2, label


dataset = Dataset('train')

sentence1, sentence2, label = dataset[0]

len(dataset), sentence1, sentence2, label

(9600, '选择珠江花园的原因就是方便，有电动扶梯直', 'IE浏览器,触摸屏烫手,上面可以煎鸡蛋,', 1)

In [3]:
from transformers import BertTokenizer

#加载字典和分词工具
token = BertTokenizer.from_pretrained('bert-base-chinese')

token

PreTrainedTokenizer(name_or_path='bert-base-chinese', vocab_size=21128, model_max_len=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

In [9]:
def collate_fn(data):
    
#     print("数据data:", data)
    
    sents = [i[:2] for i in data]
    labels = [i[2] for i in data]
    
#     print("句子sents:", sents)
#     print("标签labels:", labels)
    
    #编码
    data = token.batch_encode_plus(batch_text_or_text_pairs=sents,
                                   truncation=True,
                                   padding='max_length',
                                   max_length=45,
                                   return_tensors='pt',
                                   return_length=True,
                                   add_special_tokens=True)

    #input_ids:编码之后的数字
    #attention_mask:是补零的位置是0,其他位置是1
    #token_type_ids:第一个句子和特殊符号的位置是0,第二个句子的位置是1
    input_ids = data['input_ids']
    attention_mask = data['attention_mask']
    token_type_ids = data['token_type_ids']
    labels = torch.LongTensor(labels)

#     print("张量labels:", labels)
    
    #print(data['length'], data['length'].max())

    return input_ids, attention_mask, token_type_ids, labels


#数据加载器
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=8,
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)

for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
    break

print(len(loader))
print(token.decode(input_ids[0]))
print(len(input_ids[0]))
input_ids.shape, attention_mask.shape, token_type_ids.shape, labels

1200
[CLS] 酒 店 太 老 旧 了 。 走 在 地 板 上 都 有 震 动 的 感 觉 。 [SEP] 厕 所 内 的 手 纸 就 一 点 点 。 220 的 价 格 还 不 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD]
45


(torch.Size([8, 45]),
 torch.Size([8, 45]),
 torch.Size([8, 45]),
 tensor([0, 0, 0, 1, 0, 1, 1, 1]))

In [12]:
from transformers import BertModel

#加载预训练模型
pretrained = BertModel.from_pretrained('bert-base-chinese')

#不训练,不需要计算梯度
for param in pretrained.parameters():
    param.requires_grad_(False)

#模型试算
out = pretrained(input_ids=input_ids,
           attention_mask=attention_mask,
           token_type_ids=token_type_ids)

out.last_hidden_state.shape

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


torch.Size([8, 45, 768])

In [24]:
#定义下游任务模型
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(768, 2)

    def forward(self, input_ids, attention_mask, token_type_ids):
        with torch.no_grad():
            out = pretrained(input_ids=input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids)
            
#         print("我的out.last_hidden_state.shape:", out.last_hidden_state.shape)
#         print("out.last_hidden_state[:, 0].shape", out.last_hidden_state[:, 0].shape)
        out = self.fc(out.last_hidden_state[:, 0])
#         print("out.shape:", out.shape)
        out = out.softmax(dim=1)
        return out

model = Model()

model(input_ids=input_ids,
      attention_mask=attention_mask,
      token_type_ids=token_type_ids).shape

torch.Size([8, 2])

In [28]:
from transformers import AdamW

#训练
optimizer = AdamW(model.parameters(), lr=5e-4)

criterion = torch.nn.CrossEntropyLoss()

model.train()

print("loader.bath_size:", loader.batch_size)

for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
    
    out = model(input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids)
    
    print("out.shape:", out.shape)
    print("labels.shape:", labels.shape)
    
    print("out:", out)
    print("labels:", labels)
    
    loss = criterion(out, labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if i % 5 == 0:
        out = out.argmax(dim=1)
        accuracy = (out == labels).sum().item() / len(labels)
        print(i, loss.item(), accuracy)

    if i == 300:
        break



loader.bath_size: 8
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[9.9370e-01, 6.3047e-03],
        [9.5445e-01, 4.5548e-02],
        [2.8223e-06, 1.0000e+00],
        [1.0716e-02, 9.8928e-01],
        [8.3153e-01, 1.6847e-01],
        [1.4481e-04, 9.9986e-01],
        [8.5597e-04, 9.9914e-01],
        [5.7394e-03, 9.9426e-01]], grad_fn=<SoftmaxBackward0>)
labels: tensor([0, 0, 1, 1, 1, 1, 1, 1])
0 0.41370898485183716 0.875
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[0.9753, 0.0247],
        [0.9970, 0.0030],
        [0.4270, 0.5730],
        [0.0288, 0.9712],
        [0.9988, 0.0012],
        [0.9787, 0.0213],
        [0.9968, 0.0032],
        [0.9904, 0.0096]], grad_fn=<SoftmaxBackward0>)
labels: tensor([0, 0, 0, 1, 0, 1, 0, 0])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[9.8795e-01, 1.2046e-02],
        [9.9833e-01, 1.6696e-03],
        [9.3695e-07, 1.0000e+00],
        [9.7419e-01, 2.5806e-02],
 

out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[9.9960e-01, 3.9827e-04],
        [9.2600e-01, 7.4000e-02],
        [4.6591e-02, 9.5341e-01],
        [9.9783e-01, 2.1650e-03],
        [2.2093e-01, 7.7907e-01],
        [3.0614e-04, 9.9969e-01],
        [4.1456e-04, 9.9959e-01],
        [9.9836e-01, 1.6392e-03]], grad_fn=<SoftmaxBackward0>)
labels: tensor([0, 1, 1, 0, 1, 1, 1, 0])
20 0.4459918141365051 0.875
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[4.6078e-04, 9.9954e-01],
        [5.5543e-04, 9.9944e-01],
        [9.0955e-01, 9.0452e-02],
        [9.8573e-01, 1.4268e-02],
        [4.5960e-02, 9.5404e-01],
        [3.5287e-05, 9.9996e-01],
        [9.9180e-01, 8.2014e-03],
        [9.9893e-01, 1.0726e-03]], grad_fn=<SoftmaxBackward0>)
labels: tensor([1, 1, 1, 1, 1, 1, 0, 0])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[4.1439e-04, 9.9959e-01],
        [4.8987e-04, 9.9951e-01],
        [9.9421e-01, 5.78

out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[4.4829e-03, 9.9552e-01],
        [5.8607e-01, 4.1393e-01],
        [9.8678e-01, 1.3222e-02],
        [1.4057e-03, 9.9859e-01],
        [1.4824e-01, 8.5176e-01],
        [9.8256e-01, 1.7435e-02],
        [3.3749e-01, 6.6251e-01],
        [1.1104e-07, 1.0000e+00]], grad_fn=<SoftmaxBackward0>)
labels: tensor([1, 0, 0, 1, 1, 0, 0, 1])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[2.7509e-04, 9.9972e-01],
        [9.9953e-01, 4.6656e-04],
        [9.8115e-02, 9.0188e-01],
        [8.6773e-07, 1.0000e+00],
        [7.6501e-01, 2.3499e-01],
        [8.6267e-01, 1.3733e-01],
        [9.8607e-01, 1.3932e-02],
        [8.5668e-01, 1.4332e-01]], grad_fn=<SoftmaxBackward0>)
labels: tensor([1, 1, 1, 1, 0, 0, 0, 0])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[9.8255e-01, 1.7455e-02],
        [9.9735e-01, 2.6532e-03],
        [9.9633e-01, 3.6668e-03],
        [5.4465e-02

out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[9.9877e-01, 1.2302e-03],
        [9.8907e-01, 1.0927e-02],
        [5.4015e-01, 4.5985e-01],
        [3.2877e-05, 9.9997e-01],
        [1.0609e-05, 9.9999e-01],
        [9.8973e-01, 1.0272e-02],
        [9.9985e-01, 1.4988e-04],
        [9.8643e-01, 1.3569e-02]], grad_fn=<SoftmaxBackward0>)
labels: tensor([0, 0, 1, 1, 1, 0, 0, 0])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[9.9981e-01, 1.9306e-04],
        [1.4108e-02, 9.8589e-01],
        [2.2193e-03, 9.9778e-01],
        [9.9191e-01, 8.0869e-03],
        [9.9796e-01, 2.0432e-03],
        [9.9989e-01, 1.0881e-04],
        [7.5820e-05, 9.9992e-01],
        [8.3724e-01, 1.6276e-01]], grad_fn=<SoftmaxBackward0>)
labels: tensor([0, 1, 1, 0, 0, 0, 1, 1])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[9.8015e-01, 1.9851e-02],
        [7.7925e-05, 9.9992e-01],
        [2.3417e-05, 9.9998e-01],
        [3.3986e-01

out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[4.5924e-01, 5.4076e-01],
        [1.5762e-04, 9.9984e-01],
        [1.6566e-01, 8.3434e-01],
        [9.4976e-01, 5.0240e-02],
        [9.7851e-01, 2.1490e-02],
        [9.9748e-01, 2.5248e-03],
        [6.1693e-01, 3.8307e-01],
        [9.3876e-01, 6.1240e-02]], grad_fn=<SoftmaxBackward0>)
labels: tensor([0, 1, 1, 0, 0, 0, 0, 0])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[9.8295e-01, 1.7055e-02],
        [9.9818e-01, 1.8166e-03],
        [9.8813e-01, 1.1868e-02],
        [9.9072e-01, 9.2839e-03],
        [9.9989e-01, 1.1416e-04],
        [1.6869e-01, 8.3131e-01],
        [9.9891e-01, 1.0885e-03],
        [5.5871e-06, 9.9999e-01]], grad_fn=<SoftmaxBackward0>)
labels: tensor([0, 0, 1, 0, 0, 1, 0, 1])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[1.2100e-01, 8.7900e-01],
        [9.9984e-01, 1.6263e-04],
        [5.5178e-01, 4.4822e-01],
        [2.1394e-05

out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[1.6317e-05, 9.9998e-01],
        [9.9993e-01, 6.5885e-05],
        [4.0153e-02, 9.5985e-01],
        [9.8060e-01, 1.9398e-02],
        [1.0591e-04, 9.9989e-01],
        [6.5401e-07, 1.0000e+00],
        [9.8337e-01, 1.6632e-02],
        [9.6892e-01, 3.1083e-02]], grad_fn=<SoftmaxBackward0>)
labels: tensor([1, 0, 0, 0, 1, 1, 0, 0])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[4.7572e-01, 5.2428e-01],
        [5.0872e-04, 9.9949e-01],
        [9.2024e-03, 9.9080e-01],
        [9.9999e-01, 7.6993e-06],
        [2.4596e-01, 7.5404e-01],
        [9.2734e-01, 7.2662e-02],
        [6.9121e-02, 9.3088e-01],
        [9.9697e-01, 3.0332e-03]], grad_fn=<SoftmaxBackward0>)
labels: tensor([1, 1, 1, 0, 1, 0, 1, 0])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[9.9984e-01, 1.6443e-04],
        [9.9964e-01, 3.5899e-04],
        [9.8876e-01, 1.1239e-02],
        [2.1600e-06

out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[4.9270e-01, 5.0730e-01],
        [9.9995e-01, 4.8635e-05],
        [9.9829e-01, 1.7079e-03],
        [8.2968e-01, 1.7032e-01],
        [1.3856e-03, 9.9861e-01],
        [6.2112e-05, 9.9994e-01],
        [1.1254e-02, 9.8875e-01],
        [9.9068e-01, 9.3213e-03]], grad_fn=<SoftmaxBackward0>)
labels: tensor([1, 0, 0, 0, 1, 1, 1, 1])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[9.9195e-01, 8.0501e-03],
        [9.8341e-01, 1.6590e-02],
        [9.9345e-01, 6.5533e-03],
        [8.6195e-02, 9.1381e-01],
        [9.8764e-01, 1.2365e-02],
        [9.9767e-01, 2.3317e-03],
        [9.9963e-01, 3.7004e-04],
        [7.9530e-01, 2.0470e-01]], grad_fn=<SoftmaxBackward0>)
labels: tensor([0, 0, 0, 1, 0, 0, 0, 1])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[2.7944e-01, 7.2056e-01],
        [8.8373e-01, 1.1627e-01],
        [9.6823e-01, 3.1773e-02],
        [8.7131e-01

out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[9.6565e-01, 3.4349e-02],
        [9.9974e-01, 2.6195e-04],
        [9.9997e-01, 3.4373e-05],
        [7.8069e-03, 9.9219e-01],
        [9.9996e-01, 3.8664e-05],
        [1.0000e+00, 7.7794e-08],
        [9.3747e-02, 9.0625e-01],
        [9.0049e-01, 9.9508e-02]], grad_fn=<SoftmaxBackward0>)
labels: tensor([0, 0, 0, 0, 0, 0, 1, 0])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[9.9959e-01, 4.1403e-04],
        [8.4392e-04, 9.9916e-01],
        [3.1012e-01, 6.8988e-01],
        [6.1946e-05, 9.9994e-01],
        [9.5558e-05, 9.9990e-01],
        [9.9983e-01, 1.7385e-04],
        [5.3824e-02, 9.4618e-01],
        [2.2354e-02, 9.7765e-01]], grad_fn=<SoftmaxBackward0>)
labels: tensor([0, 1, 1, 1, 1, 0, 1, 1])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[8.0489e-01, 1.9511e-01],
        [9.9999e-01, 1.0376e-05],
        [1.1380e-01, 8.8620e-01],
        [9.5231e-01

out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[9.9997e-01, 3.1428e-05],
        [9.8498e-01, 1.5021e-02],
        [9.9934e-01, 6.5806e-04],
        [9.9994e-01, 5.8169e-05],
        [5.7848e-07, 1.0000e+00],
        [8.7911e-01, 1.2089e-01],
        [9.9826e-01, 1.7428e-03],
        [9.9885e-01, 1.1488e-03]], grad_fn=<SoftmaxBackward0>)
labels: tensor([0, 0, 0, 0, 1, 0, 0, 0])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[9.9803e-01, 1.9728e-03],
        [9.9934e-01, 6.5911e-04],
        [9.9884e-01, 1.1610e-03],
        [9.9097e-01, 9.0341e-03],
        [1.1784e-05, 9.9999e-01],
        [5.8523e-01, 4.1477e-01],
        [1.3636e-05, 9.9999e-01],
        [9.1244e-01, 8.7559e-02]], grad_fn=<SoftmaxBackward0>)
labels: tensor([0, 1, 1, 0, 1, 0, 1, 0])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[9.9876e-01, 1.2447e-03],
        [6.2661e-02, 9.3734e-01],
        [1.3342e-05, 9.9999e-01],
        [9.9946e-01

out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[9.8935e-01, 1.0653e-02],
        [1.0953e-06, 1.0000e+00],
        [9.9641e-01, 3.5917e-03],
        [1.4019e-02, 9.8598e-01],
        [2.5052e-01, 7.4948e-01],
        [1.6594e-04, 9.9983e-01],
        [9.9885e-01, 1.1504e-03],
        [9.9880e-01, 1.1992e-03]], grad_fn=<SoftmaxBackward0>)
labels: tensor([1, 1, 0, 1, 0, 1, 0, 0])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[4.6033e-03, 9.9540e-01],
        [9.9031e-01, 9.6871e-03],
        [9.9242e-01, 7.5844e-03],
        [9.9882e-01, 1.1809e-03],
        [6.9920e-01, 3.0080e-01],
        [2.9379e-03, 9.9706e-01],
        [2.7849e-04, 9.9972e-01],
        [9.8586e-02, 9.0141e-01]], grad_fn=<SoftmaxBackward0>)
labels: tensor([1, 0, 0, 0, 0, 1, 1, 1])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[8.1238e-01, 1.8762e-01],
        [1.6670e-07, 1.0000e+00],
        [6.5447e-01, 3.4553e-01],
        [2.1678e-03

out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[9.9999e-01, 7.4728e-06],
        [3.7307e-04, 9.9963e-01],
        [9.9916e-01, 8.3849e-04],
        [2.9763e-05, 9.9997e-01],
        [9.2884e-01, 7.1159e-02],
        [1.5282e-01, 8.4718e-01],
        [1.3992e-01, 8.6008e-01],
        [9.9722e-01, 2.7770e-03]], grad_fn=<SoftmaxBackward0>)
labels: tensor([0, 1, 0, 1, 1, 1, 1, 0])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[3.5260e-01, 6.4740e-01],
        [9.9519e-01, 4.8104e-03],
        [1.2706e-06, 1.0000e+00],
        [2.5463e-01, 7.4537e-01],
        [9.9375e-01, 6.2513e-03],
        [9.9688e-01, 3.1223e-03],
        [2.1484e-05, 9.9998e-01],
        [9.9995e-01, 5.1750e-05]], grad_fn=<SoftmaxBackward0>)
labels: tensor([0, 0, 1, 0, 0, 0, 1, 0])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[1.0321e-05, 9.9999e-01],
        [9.9920e-01, 8.0410e-04],
        [1.8324e-04, 9.9982e-01],
        [9.9546e-01

out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[9.9751e-01, 2.4884e-03],
        [9.9309e-01, 6.9149e-03],
        [4.1669e-06, 1.0000e+00],
        [5.6102e-03, 9.9439e-01],
        [3.7486e-08, 1.0000e+00],
        [2.9715e-08, 1.0000e+00],
        [9.9998e-01, 1.8628e-05],
        [9.9961e-01, 3.9349e-04]], grad_fn=<SoftmaxBackward0>)
labels: tensor([0, 0, 1, 1, 1, 1, 0, 0])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[9.2970e-01, 7.0300e-02],
        [9.9620e-01, 3.7961e-03],
        [7.0713e-03, 9.9293e-01],
        [9.7213e-01, 2.7869e-02],
        [1.2912e-05, 9.9999e-01],
        [4.8790e-05, 9.9995e-01],
        [9.9982e-01, 1.7992e-04],
        [9.9988e-01, 1.1978e-04]], grad_fn=<SoftmaxBackward0>)
labels: tensor([0, 0, 1, 1, 1, 1, 0, 0])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[9.7021e-01, 2.9795e-02],
        [4.3496e-01, 5.6504e-01],
        [9.9151e-01, 8.4942e-03],
        [9.3016e-07

out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[2.9304e-02, 9.7070e-01],
        [9.9994e-01, 5.7662e-05],
        [1.0000e+00, 1.7501e-06],
        [9.9866e-01, 1.3398e-03],
        [5.6829e-01, 4.3171e-01],
        [9.3990e-01, 6.0096e-02],
        [7.4721e-03, 9.9253e-01],
        [9.9993e-01, 6.8434e-05]], grad_fn=<SoftmaxBackward0>)
labels: tensor([1, 0, 0, 0, 0, 0, 1, 0])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[9.4831e-01, 5.1689e-02],
        [5.7170e-04, 9.9943e-01],
        [9.9342e-01, 6.5798e-03],
        [9.9994e-01, 5.6773e-05],
        [9.8778e-06, 9.9999e-01],
        [9.8900e-01, 1.0996e-02],
        [9.9964e-01, 3.5743e-04],
        [9.9958e-01, 4.1663e-04]], grad_fn=<SoftmaxBackward0>)
labels: tensor([0, 1, 0, 0, 1, 0, 0, 0])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[6.1476e-01, 3.8524e-01],
        [9.5943e-01, 4.0571e-02],
        [9.9667e-01, 3.3320e-03],
        [4.2888e-01

out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[9.9979e-01, 2.0819e-04],
        [1.4015e-02, 9.8599e-01],
        [3.6521e-02, 9.6348e-01],
        [5.0410e-04, 9.9950e-01],
        [9.3432e-01, 6.5678e-02],
        [9.0000e-01, 9.9998e-02],
        [9.1769e-04, 9.9908e-01],
        [9.9899e-01, 1.0067e-03]], grad_fn=<SoftmaxBackward0>)
labels: tensor([0, 1, 1, 1, 0, 1, 1, 0])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[9.7426e-01, 2.5744e-02],
        [9.7580e-01, 2.4204e-02],
        [3.9928e-03, 9.9601e-01],
        [7.3359e-01, 2.6641e-01],
        [9.8408e-01, 1.5917e-02],
        [9.8232e-01, 1.7683e-02],
        [9.7534e-01, 2.4664e-02],
        [9.9986e-01, 1.3546e-04]], grad_fn=<SoftmaxBackward0>)
labels: tensor([1, 0, 0, 0, 0, 1, 0, 0])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[1.8312e-04, 9.9982e-01],
        [8.2695e-01, 1.7305e-01],
        [9.8609e-01, 1.3910e-02],
        [2.6979e-03

out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[8.5677e-01, 1.4323e-01],
        [9.9868e-01, 1.3211e-03],
        [9.9997e-01, 3.3092e-05],
        [9.8645e-01, 1.3553e-02],
        [9.5331e-01, 4.6695e-02],
        [9.9989e-01, 1.1325e-04],
        [1.3220e-05, 9.9999e-01],
        [4.6413e-05, 9.9995e-01]], grad_fn=<SoftmaxBackward0>)
labels: tensor([1, 0, 0, 0, 0, 0, 1, 1])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[2.8623e-07, 1.0000e+00],
        [9.2703e-01, 7.2972e-02],
        [9.7586e-01, 2.4142e-02],
        [9.0133e-01, 9.8670e-02],
        [9.9683e-01, 3.1696e-03],
        [9.6577e-01, 3.4232e-02],
        [9.9998e-01, 2.4051e-05],
        [6.1197e-02, 9.3880e-01]], grad_fn=<SoftmaxBackward0>)
labels: tensor([1, 1, 0, 0, 0, 0, 0, 1])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[1.7952e-06, 1.0000e+00],
        [9.6122e-03, 9.9039e-01],
        [1.7538e-02, 9.8246e-01],
        [2.6042e-04

In [10]:
#测试
def test():
    model.eval()
    correct = 0
    total = 0

    loader_test = torch.utils.data.DataLoader(dataset=Dataset('test'),
                                              batch_size=32,
                                              collate_fn=collate_fn,
                                              shuffle=True,
                                              drop_last=True)

    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader_test):
        if i == 5:
            break

        print(i)

        with torch.no_grad():
            out = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        token_type_ids=token_type_ids)

        pred = out.argmax(dim=1)

        correct += (pred == labels).sum().item()
        total += len(labels)

    print(correct / total)


test()

0
1
2
3
4
0.925
