# 实战：中文分类

数据集是 `seamew/ChnSentiCorp`，共 9600 句话，每句话有个 label 表示这是正面情感还是负面情感，也就是我们要对这句话做**二分类**。

## 1. 加载 dataset

In [2]:
import torch
from datasets import load_dataset


#定义数据集
class Dataset(torch.utils.data.Dataset):
    def __init__(self, split):
        self.dataset = load_dataset(path='seamew/ChnSentiCorp', split=split)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        text = self.dataset[i]['text']
        label = self.dataset[i]['label']

        return text, label


dataset = Dataset('train')

len(dataset), dataset[0]

Found cached dataset chn_senti_corp (/root/.cache/huggingface/datasets/seamew___chn_senti_corp/default/0.0.0/1f242195a37831906957a11a2985a4329167e60657c07dc95ebe266c03fdfb85)


(9600,
 ('选择珠江花园的原因就是方便，有电动扶梯直接到达海边，周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般，但还算整洁。 泳池在大堂的屋顶，因此很小，不过女儿倒是喜欢。 包的早餐是西式的，还算丰富。 服务吗，一般',
  1))

## 2. 加载 tokenizer

In [3]:
from transformers import BertTokenizer

#加载字典和分词工具
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

tokenizer

BertTokenizer(name_or_path='bert-base-chinese', vocab_size=21128, model_max_length=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

## 3. 定义 collate_fn

由于原始数据集的数据是文本类型，这里定义 `collate_fn` 函数，将加载的文本数据转换成经过 tokenizer 编码的结果

In [4]:
def collate_fn(data):
    """
    我们的原始数据集是一句话，这里对文本进行分词并进行编码
    :param data: 一个 batch 的 sentences
    """
    sents = [i[0] for i in data]
    labels = [i[1] for i in data]

    #编码
    data = tokenizer.batch_encode_plus(batch_text_or_text_pairs=sents,
                                   truncation=True,
                                   padding='max_length',
                                   max_length=500,
                                   return_tensors='pt',
                                   return_length=True)

    # input_ids: 编码之后的数字
    # attention_mask: 是补零的位置是0,其他位置是1
    input_ids = data['input_ids']
    attention_mask = data['attention_mask']
    token_type_ids = data['token_type_ids']
    labels = torch.LongTensor(labels)

    #print(data['length'], data['length'].max())

    return input_ids, attention_mask, token_type_ids, labels


#数据加载器
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=16,
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)

for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
    break

print(len(loader))
input_ids.shape, attention_mask.shape, token_type_ids.shape, labels

600


(torch.Size([16, 500]),
 torch.Size([16, 500]),
 torch.Size([16, 500]),
 tensor([1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0]))

## 4. 加载 pretrained model

这里选择加载 `bert-base-chinese` 模型

In [5]:
from transformers import BertModel

# 加载预训练模型
pretrained = BertModel.from_pretrained('bert-base-chinese')

# 不做 fine-tuning，不需要计算梯度
for param in pretrained.parameters():
    param.requires_grad_(False)

# 模型试算
out = pretrained(input_ids=input_ids,
           attention_mask=attention_mask,
           token_type_ids=token_type_ids)

out.last_hidden_state.shape  # [batch, 分词长度（一句话被切分成多少个 token）, token 的 embed_dim]

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


torch.Size([16, 500, 768])

## 5. 定义下游模型

这个下游模型只包含一个全连接的 FC 层，用来对 BERT 得到的 representation 做分类

In [6]:
#定义下游任务模型
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(768, 2)  # out_feature 为 2，因为要进行二分类

    def forward(self, input_ids, attention_mask, token_type_ids):
        with torch.no_grad():
            out = pretrained(input_ids=input_ids,
                       attention_mask=attention_mask,
                       token_type_ids=token_type_ids)

        out = self.fc(out.last_hidden_state[:, 0])  # 拿 [CLS] 这个 special token 做分类

        out = out.softmax(dim=1)

        return out


model = Model()

out = model(input_ids=input_ids,
      attention_mask=attention_mask,
      token_type_ids=token_type_ids)

print('shape: ', out.shape)
print('out[0]: \n', out[0])

shape:  torch.Size([16, 2])
out[0]: 
 tensor([0.5596, 0.4404], grad_fn=<SelectBackward0>)


## 6. 训练下游模型

只对额外加上的 FC 层进行训练，并不训练 pretrained 的 BERT 层。

经过简单的训练之后，就可以看到准确率已经达到了百分之七八十的样子，这就是使用 BERT 来抽取 feature 的威力。在以往想要达到这样的效果，往往需要特别大的数据量才可以，而且还难以收敛。

In [7]:
from transformers import AdamW

# 训练
optimizer = AdamW(model.parameters(), lr=5e-4)
criterion = torch.nn.CrossEntropyLoss()

model.train()
for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
    out = model(input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids)

    loss = criterion(out, labels)  # out: [B, 2]  labels: [B]
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if i % 5 == 0:
        out = out.argmax(dim=1)
        accuracy = (out == labels).sum().item() / len(labels)

        print(i, loss.item(), accuracy)

    if i == 30:
        break



0 0.7267506122589111 0.4375
5 0.6897894144058228 0.5
10 0.6477447152137756 0.75
15 0.6398004293441772 0.8125
20 0.6133301258087158 0.6875
25 0.5809521675109863 0.875
30 0.5991812944412231 0.75


## 7. 模型测试

In [8]:
#测试
def test():
    model.eval()
    correct = 0
    total = 0

    loader_test = torch.utils.data.DataLoader(dataset=Dataset('validation'),
                                              batch_size=32,
                                              collate_fn=collate_fn,
                                              shuffle=True,
                                              drop_last=True)

    for i, (input_ids, attention_mask, token_type_ids,
            labels) in enumerate(loader_test):

        if i == 5:
            break

        print(i)

        with torch.no_grad():
            out = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        token_type_ids=token_type_ids)

        out = out.argmax(dim=1)
        correct += (out == labels).sum().item()
        total += len(labels)

    print(correct / total)


test()

Found cached dataset chn_senti_corp (/root/.cache/huggingface/datasets/seamew___chn_senti_corp/default/0.0.0/1f242195a37831906957a11a2985a4329167e60657c07dc95ebe266c03fdfb85)


0
1
2
3
4
0.8625
