In [2]:
import torch
from datasets import load_dataset
from datasets import load_from_disk
import random


#定义数据集
class Dataset(torch.utils.data.Dataset):
    def __init__(self, split):
#         dataset = load_dataset(path='seamew/ChnSentiCorp', split=split)

#         def f(data):
#             return len(data['text']) > 40

#         self.dataset = dataset.filter(f)
        dataset = load_from_disk('/Users/zard/Documents/nlp002/Huggingface_Toturials/data/ChnSentiCorp')
        self.dataset = dataset[split]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        text = self.dataset[i]['text']

        #切分一句话为前半句和后半句
        sentence1 = text[:20]
        sentence2 = text[20:40]
        label = 0

        #有一半的概率把后半句替换为一句无关的话
        if random.randint(0, 1) == 0:
            j = random.randint(0, len(self.dataset) - 1)
            sentence2 = self.dataset[j]['text'][20:40]
            label = 1

        return sentence1, sentence2, label


dataset = Dataset('train')

sentence1, sentence2, label = dataset[0]

len(dataset), sentence1, sentence2, label

(9600, '选择珠江花园的原因就是方便，有电动扶梯直', 'IE浏览器,触摸屏烫手,上面可以煎鸡蛋,', 1)

In [3]:
from transformers import BertTokenizer

#加载字典和分词工具
token = BertTokenizer.from_pretrained('bert-base-chinese')

token

PreTrainedTokenizer(name_or_path='bert-base-chinese', vocab_size=21128, model_max_len=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

In [5]:
def collate_fn(data):
    
#     print("数据data:", data)
    
    sents = [i[:2] for i in data]
    labels = [i[2] for i in data]
    
#     print("句子sents:", sents)
#     print("标签labels:", labels)
    
    #编码
    data = token.batch_encode_plus(batch_text_or_text_pairs=sents,
                                   truncation=True,
                                   padding='max_length',
                                   max_length=45,
                                   return_tensors='pt',
                                   return_length=True,
                                   add_special_tokens=True)

    #input_ids:编码之后的数字
    #attention_mask:是补零的位置是0,其他位置是1
    #token_type_ids:第一个句子和特殊符号的位置是0,第二个句子的位置是1
    input_ids = data['input_ids']
    attention_mask = data['attention_mask']
    token_type_ids = data['token_type_ids']
    labels = torch.LongTensor(labels)

#     print("张量labels:", labels)
    
    #print(data['length'], data['length'].max())

    return input_ids, attention_mask, token_type_ids, labels


#数据加载器
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=8,
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)

for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
    break

print(len(loader))
print(token.decode(input_ids[0]))
input_ids.shape, attention_mask.shape, token_type_ids.shape, labels

数据data: [('女人，尤其是女孩，心思细密，敏感浪漫，更', '容易好高务远，一想天开。这本书就是让女人', 0), ('商品描述与实物一致，确实是惠普的商用顶级', '用时间。集成的显卡稍微弱一些。', 1), ('整体感觉不如IdeaPad的本本。USB', '卡收货了。搞得我还要自己去话说维修点检测', 1), ('1 整体做工精细 2 白色的基本不会有手', '印的困扰，这是我选择白色的主要原因 3 ', 0), ('插上电源开机。发现电源按钮旁边有划痕，联', '。 地点十分方便：旁边是大马路解放路，往', 1), ('酒店的位置非常好，过了街就是繁华的南京路', '，离地铁站也特别近。酒店的硬件在4星级里', 0), ('环境一流，庭院深深，小桥流水，枯藤老树昏', '鸦。多次入住。只是早餐极其惨淡，无论中西', 0), ('临街.非常的吵.酒店的负责清洁的都是顺手', '牵羊或者无德之辈.共住了三次好像.前二次', 0)]
句子sents: [('女人，尤其是女孩，心思细密，敏感浪漫，更', '容易好高务远，一想天开。这本书就是让女人'), ('商品描述与实物一致，确实是惠普的商用顶级', '用时间。集成的显卡稍微弱一些。'), ('整体感觉不如IdeaPad的本本。USB', '卡收货了。搞得我还要自己去话说维修点检测'), ('1 整体做工精细 2 白色的基本不会有手', '印的困扰，这是我选择白色的主要原因 3 '), ('插上电源开机。发现电源按钮旁边有划痕，联', '。 地点十分方便：旁边是大马路解放路，往'), ('酒店的位置非常好，过了街就是繁华的南京路', '，离地铁站也特别近。酒店的硬件在4星级里'), ('环境一流，庭院深深，小桥流水，枯藤老树昏', '鸦。多次入住。只是早餐极其惨淡，无论中西'), ('临街.非常的吵.酒店的负责清洁的都是顺手', '牵羊或者无德之辈.共住了三次好像.前二次')]
标签labels: [0, 1, 1, 0, 1, 0, 0, 0]
张量labels: tensor([0, 1, 1, 0, 1, 0, 0, 0])
1200
[CLS] 女 人 ， 尤 其 是 女 孩 ， 心 思 细 密 ， 敏 感 浪 漫 ， 更 [SEP] 容 易 好 高 务 远 ， 一 想 天 开 。 这 本 书 就 

(torch.Size([8, 45]),
 torch.Size([8, 45]),
 torch.Size([8, 45]),
 tensor([0, 1, 1, 0, 1, 0, 0, 0]))

In [7]:
from transformers import BertModel

#加载预训练模型
pretrained = BertModel.from_pretrained('bert-base-chinese')

#不训练,不需要计算梯度
for param in pretrained.parameters():
    param.requires_grad_(False)

#模型试算
out = pretrained(input_ids=input_ids,
           attention_mask=attention_mask,
           token_type_ids=token_type_ids)

out.last_hidden_state.shape

Downloading:   0%|          | 0.00/393M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


torch.Size([8, 45, 768])

In [8]:
#定义下游任务模型
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(768, 2)

    def forward(self, input_ids, attention_mask, token_type_ids):
        with torch.no_grad():
            out = pretrained(input_ids=input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids)

        out = self.fc(out.last_hidden_state[:, 0])

        out = out.softmax(dim=1)

        return out


model = Model()

model(input_ids=input_ids,
      attention_mask=attention_mask,
      token_type_ids=token_type_ids).shape

torch.Size([8, 2])

In [9]:
from transformers import AdamW

#训练
optimizer = AdamW(model.parameters(), lr=5e-4)
criterion = torch.nn.CrossEntropyLoss()

model.train()
for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
    out = model(input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids)

    loss = criterion(out, labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if i % 5 == 0:
        out = out.argmax(dim=1)
        accuracy = (out == labels).sum().item() / len(labels)
        print(i, loss.item(), accuracy)

    if i == 300:
        break



0 0.6829604506492615 0.5
5 0.5733257532119751 0.875
10 0.6138834953308105 0.75
15 0.4034173786640167 1.0
20 0.4893438220024109 0.875
25 0.4256744086742401 1.0
30 0.44798508286476135 0.875
35 0.49994683265686035 0.875
40 0.3902716338634491 1.0
45 0.47605112195014954 0.875
50 0.38127291202545166 1.0
55 0.46472567319869995 0.875
60 0.6686863899230957 0.5
65 0.4596948027610779 0.75
70 0.4620363712310791 0.875
75 0.4030296504497528 1.0
80 0.38339099287986755 1.0
85 0.4915534257888794 0.75
90 0.41055816411972046 0.875
95 0.578237771987915 0.75
100 0.3514157831668854 1.0
105 0.4582573473453522 0.875
110 0.41087818145751953 1.0
115 0.5620350241661072 0.75
120 0.5009777545928955 0.875
125 0.4349495768547058 0.875
130 0.4879797697067261 0.75
135 0.47779300808906555 0.875
140 0.4227628707885742 0.875
145 0.35516491532325745 1.0
150 0.5687816739082336 0.625
155 0.43313172459602356 0.875
160 0.374252051115036 0.875
165 0.3627183139324188 1.0
170 0.4084283113479614 0.875
175 0.46436724066734314 0.75

In [10]:
#测试
def test():
    model.eval()
    correct = 0
    total = 0

    loader_test = torch.utils.data.DataLoader(dataset=Dataset('test'),
                                              batch_size=32,
                                              collate_fn=collate_fn,
                                              shuffle=True,
                                              drop_last=True)

    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader_test):
        if i == 5:
            break

        print(i)

        with torch.no_grad():
            out = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        token_type_ids=token_type_ids)

        pred = out.argmax(dim=1)

        correct += (pred == labels).sum().item()
        total += len(labels)

    print(correct / total)


test()

0
1
2
3
4
0.925
