In [1]:
import torch
from datasets import load_dataset
from datasets import load_from_disk
import random


#定义数据集
class Dataset(torch.utils.data.Dataset):
    def __init__(self, split):
#         dataset = load_dataset(path='seamew/ChnSentiCorp', split=split)

#         def f(data):
#             return len(data['text']) > 40

#         self.dataset = dataset.filter(f)
        dataset = load_from_disk('/Users/zard/Documents/nlp002/Huggingface_Toturials/data/ChnSentiCorp')
        self.dataset = dataset[split]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        text = self.dataset[i]['text']

        #切分一句话为前半句和后半句
        sentence1 = text[:20]
        sentence2 = text[20:40]
        label = 0

        #有一半的概率把后半句替换为一句无关的话
        if random.randint(0, 1) == 0:
            j = random.randint(0, len(self.dataset) - 1)
            sentence2 = self.dataset[j]['text'][20:40]
            label = 1

        return sentence1, sentence2, label


dataset = Dataset('train')

sentence1, sentence2, label = dataset[0]

len(dataset), sentence1, sentence2, label

(9600, '选择珠江花园的原因就是方便，有电动扶梯直', '他连锁店一般免费).床垫上只有很薄的床单', 1)

In [2]:
from transformers import BertTokenizer

#加载字典和分词工具
token = BertTokenizer.from_pretrained('bert-base-chinese')

token

PreTrainedTokenizer(name_or_path='bert-base-chinese', vocab_size=21128, model_max_len=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

In [3]:
def collate_fn(data):
    
#     print("数据data:", data)
    
    sents = [i[:2] for i in data]
    labels = [i[2] for i in data]
    
#     print("句子sents:", sents)
#     print("标签labels:", labels)
    
    #编码
    data = token.batch_encode_plus(batch_text_or_text_pairs=sents,
                                   truncation=True,
                                   padding='max_length',
                                   max_length=45,
                                   return_tensors='pt',
                                   return_length=True,
                                   add_special_tokens=True)

    #input_ids:编码之后的数字
    #attention_mask:是补零的位置是0,其他位置是1
    #token_type_ids:第一个句子和特殊符号的位置是0,第二个句子的位置是1
    input_ids = data['input_ids']
    attention_mask = data['attention_mask']
    token_type_ids = data['token_type_ids']
    labels = torch.LongTensor(labels)

#     print("张量labels:", labels)
    
    #print(data['length'], data['length'].max())

    return input_ids, attention_mask, token_type_ids, labels


#数据加载器
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=8,
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)

for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
    break

print(len(loader))
print(token.decode(input_ids[0]))
print(len(input_ids[0]))
input_ids.shape, attention_mask.shape, token_type_ids.shape, labels

1200
[CLS] 房 间 升 级 到 了 行 政 楼 层 房 间 很 大 ， 很 舒 服 ， [SEP] 发 现 少 了 一 页 ， 后 来 我 申 请 换 货 ， 结 果 当 当 网 [SEP] [PAD] [PAD] [PAD]
45


(torch.Size([8, 45]),
 torch.Size([8, 45]),
 torch.Size([8, 45]),
 tensor([1, 0, 1, 0, 0, 1, 1, 0]))

In [4]:
from transformers import BertModel

#加载预训练模型
pretrained = BertModel.from_pretrained('bert-base-chinese')

#不训练,不需要计算梯度
for param in pretrained.parameters():
    param.requires_grad_(False)

#模型试算
out = pretrained(input_ids=input_ids,
           attention_mask=attention_mask,
           token_type_ids=token_type_ids)

out.last_hidden_state.shape

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


torch.Size([8, 45, 768])

In [5]:
#定义下游任务模型
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(768, 2)

    def forward(self, input_ids, attention_mask, token_type_ids):
        with torch.no_grad():
            out = pretrained(input_ids=input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids)
            
#         print("我的out.last_hidden_state.shape:", out.last_hidden_state.shape)
#         print("out.last_hidden_state[:, 0].shape", out.last_hidden_state[:, 0].shape)
        out = self.fc(out.last_hidden_state[:, 0])
#         print("out.shape:", out.shape)
        out = out.softmax(dim=1)
        return out

model = Model()

model(input_ids=input_ids,
      attention_mask=attention_mask,
      token_type_ids=token_type_ids).shape

torch.Size([8, 2])

In [6]:
from transformers import AdamW

#训练
optimizer = AdamW(model.parameters(), lr=5e-4)

criterion = torch.nn.CrossEntropyLoss()

model.train()

print("loader.bath_size:", loader.batch_size)

for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
    
    out = model(input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids)
    
    print("out.shape:", out.shape)
    print("labels.shape:", labels.shape)
    
    print("out:", out)
    print("labels:", labels)
    
    loss = criterion(out, labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if i % 5 == 0:
        out = out.argmax(dim=1)
        accuracy = (out == labels).sum().item() / len(labels)
        print(i, loss.item(), accuracy)

    if i == 300:
        break



loader.bath_size: 8
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[0.6311, 0.3689],
        [0.3689, 0.6311],
        [0.4823, 0.5177],
        [0.4450, 0.5550],
        [0.6235, 0.3765],
        [0.3645, 0.6355],
        [0.4488, 0.5512],
        [0.5474, 0.4526]], grad_fn=<SoftmaxBackward0>)
labels: tensor([1, 0, 1, 0, 1, 0, 0, 1])
0 0.780038595199585 0.125
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[0.3666, 0.6334],
        [0.6294, 0.3706],
        [0.4043, 0.5957],
        [0.6068, 0.3932],
        [0.3357, 0.6643],
        [0.7260, 0.2740],
        [0.3697, 0.6303],
        [0.5105, 0.4895]], grad_fn=<SoftmaxBackward0>)
labels: tensor([0, 0, 1, 1, 0, 1, 0, 1])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[0.6456, 0.3544],
        [0.5736, 0.4264],
        [0.5882, 0.4118],
        [0.4642, 0.5358],
        [0.5506, 0.4494],
        [0.4161, 0.5839],
        [0.5553, 0.4447],
        [0.6583, 0.34

out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[0.3535, 0.6465],
        [0.6163, 0.3837],
        [0.4512, 0.5488],
        [0.0368, 0.9632],
        [0.7113, 0.2887],
        [0.2193, 0.7807],
        [0.5210, 0.4790],
        [0.1015, 0.8985]], grad_fn=<SoftmaxBackward0>)
labels: tensor([1, 0, 1, 1, 0, 1, 1, 1])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[0.5800, 0.4200],
        [0.1077, 0.8923],
        [0.0201, 0.9799],
        [0.8769, 0.1231],
        [0.5773, 0.4227],
        [0.3442, 0.6558],
        [0.7960, 0.2040],
        [0.7230, 0.2770]], grad_fn=<SoftmaxBackward0>)
labels: tensor([1, 1, 1, 0, 1, 1, 0, 1])
25 0.5707626938819885 0.625
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[0.6485, 0.3515],
        [0.8040, 0.1960],
        [0.7901, 0.2099],
        [0.1172, 0.8828],
        [0.8755, 0.1245],
        [0.0856, 0.9144],
        [0.6688, 0.3312],
        [0.9387, 0.0613]], grad_fn=<Sof

out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[0.9639, 0.0361],
        [0.9131, 0.0869],
        [0.0248, 0.9752],
        [0.1114, 0.8886],
        [0.0100, 0.9900],
        [0.9576, 0.0424],
        [0.0183, 0.9817],
        [0.9718, 0.0282]], grad_fn=<SoftmaxBackward0>)
labels: tensor([0, 0, 1, 1, 1, 0, 1, 0])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[0.9429, 0.0571],
        [0.1302, 0.8698],
        [0.0116, 0.9884],
        [0.0805, 0.9195],
        [0.0378, 0.9622],
        [0.9569, 0.0431],
        [0.8299, 0.1701],
        [0.8278, 0.1722]], grad_fn=<SoftmaxBackward0>)
labels: tensor([0, 1, 1, 1, 1, 0, 0, 0])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[0.0969, 0.9031],
        [0.9793, 0.0207],
        [0.8719, 0.1281],
        [0.0041, 0.9959],
        [0.0276, 0.9724],
        [0.1404, 0.8596],
        [0.4956, 0.5044],
        [0.5637, 0.4363]], grad_fn=<SoftmaxBackward0>)
labels: tens

out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[0.0196, 0.9804],
        [0.8219, 0.1781],
        [0.2170, 0.7830],
        [0.9461, 0.0539],
        [0.8504, 0.1496],
        [0.0051, 0.9949],
        [0.8765, 0.1235],
        [0.7815, 0.2185]], grad_fn=<SoftmaxBackward0>)
labels: tensor([1, 0, 1, 0, 0, 1, 1, 0])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[0.4235, 0.5765],
        [0.1167, 0.8833],
        [0.9570, 0.0430],
        [0.7675, 0.2325],
        [0.1534, 0.8466],
        [0.9959, 0.0041],
        [0.3927, 0.6073],
        [0.0047, 0.9953]], grad_fn=<SoftmaxBackward0>)
labels: tensor([0, 1, 0, 1, 1, 0, 1, 1])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[0.0988, 0.9012],
        [0.9496, 0.0504],
        [0.9061, 0.0939],
        [0.0253, 0.9747],
        [0.1510, 0.8490],
        [0.8343, 0.1657],
        [0.2496, 0.7504],
        [0.9299, 0.0701]], grad_fn=<SoftmaxBackward0>)
labels: tens

out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[0.8813, 0.1187],
        [0.0221, 0.9779],
        [0.9469, 0.0531],
        [0.9872, 0.0128],
        [0.9935, 0.0065],
        [0.8930, 0.1070],
        [0.8724, 0.1276],
        [0.5300, 0.4700]], grad_fn=<SoftmaxBackward0>)
labels: tensor([0, 1, 0, 0, 0, 0, 0, 1])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[9.3476e-01, 6.5242e-02],
        [9.7207e-01, 2.7933e-02],
        [6.9449e-01, 3.0551e-01],
        [3.7929e-04, 9.9962e-01],
        [8.1443e-01, 1.8557e-01],
        [8.0664e-04, 9.9919e-01],
        [9.4484e-01, 5.5162e-02],
        [1.4165e-02, 9.8584e-01]], grad_fn=<SoftmaxBackward0>)
labels: tensor([0, 0, 0, 1, 0, 1, 0, 1])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[0.3998, 0.6002],
        [0.0022, 0.9978],
        [0.9976, 0.0024],
        [0.0038, 0.9962],
        [0.0057, 0.9943],
        [0.7317, 0.2683],
        [0.9701, 0.0299],
   

out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[8.5718e-01, 1.4282e-01],
        [2.4556e-02, 9.7544e-01],
        [1.3200e-01, 8.6800e-01],
        [9.7830e-01, 2.1704e-02],
        [7.7972e-01, 2.2028e-01],
        [8.2485e-01, 1.7515e-01],
        [9.9202e-01, 7.9780e-03],
        [6.3456e-04, 9.9937e-01]], grad_fn=<SoftmaxBackward0>)
labels: tensor([1, 1, 1, 1, 0, 0, 0, 1])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[0.0090, 0.9910],
        [0.9615, 0.0385],
        [0.4326, 0.5674],
        [0.0076, 0.9924],
        [0.9235, 0.0765],
        [0.9801, 0.0199],
        [0.2977, 0.7023],
        [0.0317, 0.9683]], grad_fn=<SoftmaxBackward0>)
labels: tensor([1, 0, 0, 1, 1, 0, 1, 1])
120 0.5128388404846191 0.75
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[0.4508, 0.5492],
        [0.0130, 0.9870],
        [0.8748, 0.1252],
        [0.0062, 0.9938],
        [0.5578, 0.4422],
        [0.0277, 0.9723],
 

out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[0.0438, 0.9562],
        [0.6717, 0.3283],
        [0.1339, 0.8661],
        [0.9160, 0.0840],
        [0.9964, 0.0036],
        [0.0044, 0.9956],
        [0.0042, 0.9958],
        [0.9974, 0.0026]], grad_fn=<SoftmaxBackward0>)
labels: tensor([1, 1, 1, 0, 0, 1, 1, 0])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[0.9880, 0.0120],
        [0.2426, 0.7574],
        [0.7907, 0.2093],
        [0.1234, 0.8766],
        [0.4602, 0.5398],
        [0.4611, 0.5389],
        [0.0010, 0.9990],
        [0.9961, 0.0039]], grad_fn=<SoftmaxBackward0>)
labels: tensor([0, 1, 1, 1, 1, 1, 1, 0])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[0.0253, 0.9747],
        [0.3884, 0.6116],
        [0.9930, 0.0070],
        [0.8876, 0.1124],
        [0.9264, 0.0736],
        [0.9917, 0.0083],
        [0.0151, 0.9849],
        [0.1791, 0.8209]], grad_fn=<SoftmaxBackward0>)
labels: tens

out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[9.6737e-01, 3.2626e-02],
        [9.4402e-01, 5.5980e-02],
        [2.2391e-04, 9.9978e-01],
        [3.9091e-01, 6.0909e-01],
        [9.6574e-01, 3.4264e-02],
        [4.2355e-04, 9.9958e-01],
        [1.8871e-03, 9.9811e-01],
        [1.3948e-01, 8.6052e-01]], grad_fn=<SoftmaxBackward0>)
labels: tensor([0, 0, 1, 0, 0, 1, 1, 1])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[8.3623e-03, 9.9164e-01],
        [9.3655e-01, 6.3448e-02],
        [8.8228e-01, 1.1772e-01],
        [1.7451e-04, 9.9983e-01],
        [1.3563e-03, 9.9864e-01],
        [7.6721e-01, 2.3279e-01],
        [9.0563e-01, 9.4365e-02],
        [9.8229e-01, 1.7706e-02]], grad_fn=<SoftmaxBackward0>)
labels: tensor([1, 0, 0, 1, 1, 0, 0, 0])
165 0.35351327061653137 1.0
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[0.9099, 0.0901],
        [0.0011, 0.9989],
        [0.0068, 0.9932],
        [0.9906

out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[0.0863, 0.9137],
        [0.9853, 0.0147],
        [0.9842, 0.0158],
        [0.0024, 0.9976],
        [0.9046, 0.0954],
        [0.9495, 0.0505],
        [0.8933, 0.1067],
        [0.9600, 0.0400]], grad_fn=<SoftmaxBackward0>)
labels: tensor([1, 0, 0, 1, 1, 0, 1, 0])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[0.1014, 0.8986],
        [0.9962, 0.0038],
        [0.0022, 0.9978],
        [0.7272, 0.2728],
        [0.9642, 0.0358],
        [0.0017, 0.9983],
        [0.7881, 0.2119],
        [0.9281, 0.0719]], grad_fn=<SoftmaxBackward0>)
labels: tensor([1, 0, 1, 0, 0, 1, 1, 0])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[9.8295e-01, 1.7048e-02],
        [9.5991e-01, 4.0094e-02],
        [1.3052e-01, 8.6948e-01],
        [8.2734e-04, 9.9917e-01],
        [4.0818e-03, 9.9592e-01],
        [9.6742e-01, 3.2583e-02],
        [9.4142e-01, 5.8582e-02],
        [9.

out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[2.5382e-02, 9.7462e-01],
        [9.9906e-01, 9.3556e-04],
        [9.8070e-01, 1.9299e-02],
        [1.7446e-01, 8.2554e-01],
        [5.9095e-04, 9.9941e-01],
        [9.5408e-01, 4.5919e-02],
        [9.4267e-01, 5.7333e-02],
        [8.9211e-01, 1.0789e-01]], grad_fn=<SoftmaxBackward0>)
labels: tensor([1, 0, 0, 1, 1, 1, 1, 1])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[0.0015, 0.9985],
        [0.8626, 0.1374],
        [0.7337, 0.2663],
        [0.9266, 0.0734],
        [0.9986, 0.0014],
        [0.0020, 0.9980],
        [0.9912, 0.0088],
        [0.9028, 0.0972]], grad_fn=<SoftmaxBackward0>)
labels: tensor([1, 0, 0, 0, 0, 1, 0, 0])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[9.8900e-01, 1.1003e-02],
        [8.1339e-01, 1.8661e-01],
        [8.7492e-01, 1.2508e-01],
        [5.3780e-01, 4.6220e-01],
        [9.7854e-01, 2.1457e-02],
        [8.7930

out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[0.0811, 0.9189],
        [0.9794, 0.0206],
        [0.9858, 0.0142],
        [0.4380, 0.5620],
        [0.5187, 0.4813],
        [0.0014, 0.9986],
        [0.5356, 0.4644],
        [0.0071, 0.9929]], grad_fn=<SoftmaxBackward0>)
labels: tensor([1, 0, 0, 1, 1, 1, 1, 1])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[0.0209, 0.9791],
        [0.9551, 0.0449],
        [0.9968, 0.0032],
        [0.6445, 0.3555],
        [0.0017, 0.9983],
        [0.2672, 0.7328],
        [0.8061, 0.1939],
        [0.9609, 0.0391]], grad_fn=<SoftmaxBackward0>)
labels: tensor([1, 0, 0, 0, 1, 1, 0, 0])
230 0.38830429315567017 1.0
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[8.6012e-01, 1.3988e-01],
        [9.7389e-01, 2.6113e-02],
        [1.4218e-03, 9.9858e-01],
        [1.4958e-04, 9.9985e-01],
        [9.6962e-01, 3.0377e-02],
        [7.3229e-02, 9.2677e-01],
        [9.9773e-

out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[9.8854e-01, 1.1459e-02],
        [3.2115e-03, 9.9679e-01],
        [4.9563e-04, 9.9950e-01],
        [4.6006e-01, 5.3994e-01],
        [6.0456e-01, 3.9544e-01],
        [8.5236e-01, 1.4764e-01],
        [2.4685e-01, 7.5315e-01],
        [8.2548e-01, 1.7452e-01]], grad_fn=<SoftmaxBackward0>)
labels: tensor([0, 1, 1, 1, 0, 1, 0, 0])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[0.1027, 0.8973],
        [0.6198, 0.3802],
        [0.0082, 0.9918],
        [0.7711, 0.2289],
        [0.0609, 0.9391],
        [0.9750, 0.0250],
        [0.9937, 0.0063],
        [0.9761, 0.0239]], grad_fn=<SoftmaxBackward0>)
labels: tensor([1, 1, 1, 0, 1, 0, 0, 1])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[3.4877e-05, 9.9997e-01],
        [9.9211e-01, 7.8939e-03],
        [9.9884e-01, 1.1579e-03],
        [3.4242e-04, 9.9966e-01],
        [5.8792e-05, 9.9994e-01],
        [7.3149

out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[6.2990e-03, 9.9370e-01],
        [2.3151e-05, 9.9998e-01],
        [1.2921e-02, 9.8708e-01],
        [8.9459e-02, 9.1054e-01],
        [6.1166e-04, 9.9939e-01],
        [8.9253e-01, 1.0747e-01],
        [8.3926e-01, 1.6074e-01],
        [9.9479e-01, 5.2099e-03]], grad_fn=<SoftmaxBackward0>)
labels: tensor([1, 1, 1, 1, 1, 0, 1, 0])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[9.1475e-01, 8.5248e-02],
        [9.8114e-01, 1.8863e-02],
        [1.2432e-02, 9.8757e-01],
        [5.3181e-01, 4.6819e-01],
        [9.9235e-01, 7.6485e-03],
        [2.8845e-04, 9.9971e-01],
        [8.2295e-05, 9.9992e-01],
        [9.8758e-01, 1.2420e-02]], grad_fn=<SoftmaxBackward0>)
labels: tensor([0, 0, 1, 0, 0, 1, 1, 0])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[8.2174e-01, 1.7826e-01],
        [9.9284e-01, 7.1560e-03],
        [7.8991e-03, 9.9210e-01],
        [9.9950e-01

out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[0.0052, 0.9948],
        [0.2826, 0.7174],
        [0.9269, 0.0731],
        [0.2430, 0.7570],
        [0.3812, 0.6188],
        [0.2966, 0.7034],
        [0.9185, 0.0815],
        [0.8884, 0.1116]], grad_fn=<SoftmaxBackward0>)
labels: tensor([1, 1, 0, 1, 1, 0, 1, 0])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[2.4807e-05, 9.9998e-01],
        [6.3547e-01, 3.6453e-01],
        [9.5378e-01, 4.6216e-02],
        [9.9910e-01, 9.0184e-04],
        [1.7027e-01, 8.2973e-01],
        [4.1965e-01, 5.8035e-01],
        [1.0594e-01, 8.9406e-01],
        [1.0078e-02, 9.8992e-01]], grad_fn=<SoftmaxBackward0>)
labels: tensor([1, 0, 0, 0, 1, 1, 0, 1])
out.shape: torch.Size([8, 2])
labels.shape: torch.Size([8])
out: tensor([[3.6083e-01, 6.3917e-01],
        [9.9531e-01, 4.6942e-03],
        [7.1018e-01, 2.8982e-01],
        [1.9214e-04, 9.9981e-01],
        [1.7513e-01, 8.2487e-01],
        [9.8801

In [10]:
#测试
def test():
    model.eval()
    correct = 0
    total = 0

    loader_test = torch.utils.data.DataLoader(dataset=Dataset('test'),
                                              batch_size=32,
                                              collate_fn=collate_fn,
                                              shuffle=True,
                                              drop_last=True)

    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader_test):
        if i == 5:
            break

        print(i)

        with torch.no_grad():
            out = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        token_type_ids=token_type_ids)

        pred = out.argmax(dim=1)

        correct += (pred == labels).sum().item()
        total += len(labels)

    print(correct / total)


test()

0
1
2
3
4
0.925
