In [1]:
import torch
from datasets import load_dataset

In [2]:
class Dataset(torch.utils.data.Dataset):
    """Map-style PyTorch dataset over one split of a Hugging Face text-classification corpus.

    Each item is a ``(text, label)`` pair taken from the ``"text"`` and
    ``"label"`` columns of the split loaded in ``__init__``.
    """

    def __init__(self, split, path="seamew/ChnSentiCorp", cache_dir="./ChnSentiCorp"):
        """Load one split of the corpus.

        Args:
            split: Split name forwarded to ``load_dataset`` (e.g. ``"train"``).
            path: Dataset identifier forwarded to ``load_dataset``; defaults to
                the ChnSentiCorp corpus this notebook uses.
            cache_dir: Local directory where the downloaded files are cached.
        """
        self.dataset = load_dataset(path, cache_dir=cache_dir, split=split)

    def __len__(self):
        """Number of examples in the loaded split."""
        return len(self.dataset)

    def __getitem__(self, i):
        """Return ``(text, label)`` for index ``i``.

        The underlying dataset is indexed once and both fields are read from
        that row, rather than indexing it separately for each field. When the
        backing dataset returns a column dict for a slice (as ``dataset[:10]``
        does in this notebook), a slice yields ``(list_of_texts, list_of_labels)``.
        """
        row = self.dataset[i]
        return row["text"], row["label"]

In [3]:
# Load the training split as (text, label) pairs.
dataset = Dataset(split="train")

Found cached dataset chn_senti_corp (/Users/byron/dev/attention/nlp_learn/ChnSentiCorp/seamew___chn_senti_corp/default/0.0.0/1f242195a37831906957a11a2985a4329167e60657c07dc95ebe266c03fdfb85)


In [7]:
# Split size plus a column-wise preview (texts, labels) of the first ten examples.
(len(dataset), dataset[:10])

(9600,
 (['选择珠江花园的原因就是方便，有电动扶梯直接到达海边，周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般，但还算整洁。 泳池在大堂的屋顶，因此很小，不过女儿倒是喜欢。 包的早餐是西式的，还算丰富。 服务吗，一般',
   '15.4寸笔记本的键盘确实爽，基本跟台式机差不多了，蛮喜欢数字小键盘，输数字特方便，样子也很美观，做工也相当不错',
   '房间太小。其他的都一般。。。。。。。。。',
   '1.接电源没有几分钟,电源适配器热的不行. 2.摄像头用不起来. 3.机盖的钢琴漆，手不能摸，一摸一个印. 4.硬盘分区不好办.',
   '今天才知道这书还有第6卷,真有点郁闷:为什么同一套书有两种版本呢?当当网是不是该跟出版社商量商量,单独出个第6卷,让我们的孩子不会有所遗憾。',
   '机器背面似乎被撕了张什么标签，残胶还在。但是又看不出是什么标签不见了，该有的都在，怪',
   '呵呵，虽然表皮看上去不错很精致，但是我还是能看得出来是盗的。但是里面的内容真的不错，我妈爱看，我自己也学着找一些穴位。',
   '这本书实在是太烂了,以前听浙大的老师说这本书怎么怎么不对,哪些地方都是误导的还不相信,终于买了一本看一下,发现真是~~~无语,这种书都写得出来',
   '地理位置佳，在市中心。酒店服务好、早餐品种丰富。我住的商务数码房电脑宽带速度满意,房间还算干净，离湖南路小吃街近。',
   '5.1期间在这住的，位置还可以，在市委市政府附近，要去商业区和步行街得打车，屋里有蚊子，虽然空间挺大，晚上熄灯后把窗帘拉上简直是伸手不见五指，很适合睡觉，但是会被该死的蚊子吵醒！打死了两只，第二天早上还是发现又没打死的，卫生间挺大，但是设备很老旧。'],
  [1, 1, 0, 0, 1, 0, 0, 0, 1, 1]))

# 加载分词器

In [8]:
from transformers import AutoTokenizer

# Tokenizer matching the bert-base-chinese checkpoint used as the encoder below.
token = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="bert-base-chinese")
token

BertTokenizerFast(name_or_path='bert-base-chinese', vocab_size=21128, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True)

# 编码批处理函数

In [143]:
def collate_fn(data, tokenizer=None, max_length=512, padding="max_length"):
    """Turn a list of ``(text, label)`` pairs into BERT input tensors.

    Used as the DataLoader ``collate_fn``; the DataLoader passes only ``data``,
    so the keyword arguments keep their defaults there.

    Args:
        data: One batch; each element is a ``(text, label)`` pair as returned
            by ``Dataset.__getitem__``.
        tokenizer: Callable tokenizer to use. Defaults to the notebook-level
            ``token`` tokenizer.
        max_length: Truncation length (and, with ``padding="max_length"``, the
            padded width of every batch).
        padding: Padding strategy forwarded to the tokenizer. The default
            ``"max_length"`` keeps the original fixed-width output.

    Returns:
        ``(input_ids, attention_mask, token_type_ids, labels)``. The first three
        are the tokenizer's ``return_tensors="pt"`` outputs; ``labels`` is a
        1-D ``torch.long`` tensor.
    """
    if tokenizer is None:
        tokenizer = token

    sents = [item[0] for item in data]
    labels = [item[1] for item in data]

    # return_length is not requested: the lengths were never used downstream.
    encoded = tokenizer(
        sents,
        padding=padding,
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )

    return (
        encoded["input_ids"],
        encoded["attention_mask"],
        encoded["token_type_ids"],
        torch.tensor(labels, dtype=torch.long),
    )

# 定义数据加载器

In [144]:
# Shuffled mini-batches of 16 encoded examples; an incomplete final batch,
# if any, is dropped so every batch has the same size.
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    drop_last=True,
    collate_fn=collate_fn,
)


In [145]:
# Pull a single batch to inspect the collated tensors. A fresh iterator is
# created, so with shuffle=True this is a random batch.
input_ids, attention_mask, token_type_ids, labels = next(iter(loader))
print(len(loader))
print(input_ids, attention_mask, token_type_ids, labels)
print(input_ids.shape, attention_mask.shape, token_type_ids.shape, labels.shape)

600
tensor([[ 101, 7509, 1510,  ...,    0,    0,    0],
        [ 101, 1692, 7509,  ...,    0,    0,    0],
        [ 101, 2769,  679,  ...,    0,    0,    0],
        ...,
        [ 101,  122, 1184,  ...,    0,    0,    0],
        [ 101, 4384, 1862,  ...,    0,    0,    0],
        [ 101, 4692, 6821,  ...,    0,    0,    0]]) tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]) tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]) tensor([0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0])
torch.Size([16, 512]) torch.Size([16, 512]) torch.Size([16, 512]) torch.Size([16])


# 加载模型

In [146]:
from transformers import AutoModel

# Pretrained Chinese BERT, used only as a frozen feature extractor.
pretrained = AutoModel.from_pretrained("bert-base-chinese")

# Not trained: freeze every parameter so no gradients are computed for it.
pretrained.requires_grad_(False)
# Make inference mode explicit (dropout off) so extracted features are deterministic.
pretrained.eval()

# Smoke-test the encoder on the sample batch; no_grad skips autograd bookkeeping.
with torch.no_grad():
    out = pretrained(
        input_ids=input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
    )

out.last_hidden_state.shape
    

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


torch.Size([16, 512, 768])

# 定义下游任务模型

In [147]:
class Model(torch.nn.Module):
    """Linear classification head on top of the frozen notebook-level ``pretrained`` encoder.

    The encoder is looked up as the global ``pretrained`` at call time and is
    not registered as a submodule. As a result, ``Model.parameters()`` yields
    only the linear head, and ``model.train()`` does not switch the encoder
    into training mode.
    """

    def __init__(self, hidden_size=768, num_labels=2):
        """
        Args:
            hidden_size: Width of the encoder's hidden states (768 for the
                bert-base-chinese encoder used here).
            num_labels: Number of output classes (2 for the binary 0/1 labels).
        """
        super().__init__()
        self.fc = torch.nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return unnormalised class logits of shape ``(batch, num_labels)``.

        Logits, not softmax probabilities, are returned because
        ``torch.nn.CrossEntropyLoss`` applies log-softmax itself. Feeding it
        probabilities would apply softmax twice, compressing the loss and
        shrinking gradients. Apply ``.softmax(dim=1)`` to the result when
        probabilities are needed; ``argmax`` predictions are unchanged either way.
        """
        # Encoder is frozen: skip autograd bookkeeping for its forward pass.
        with torch.no_grad():
            out = pretrained(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
            )
        # Hidden state at the first position (where the tokenizer puts its CLS token).
        cls_hidden = out.last_hidden_state[:, 0]
        return self.fc(cls_hidden)

In [148]:
model = Model()

# Sanity check: one row of two class scores per example in the sample batch.
model(input_ids, attention_mask, token_type_ids).shape

torch.Size([16, 2])

# 训练

In [150]:
# torch.optim.AdamW replaces the deprecated transformers.AdamW. eps and
# weight_decay are set to that class's defaults (1e-6, 0.0) to keep the update close.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, eps=1e-6, weight_decay=0.0)
criterion = torch.nn.CrossEntropyLoss()

LOG_EVERY = 5    # print loss/accuracy every N batches
MAX_STEP = 300   # stop after processing batch index MAX_STEP (MAX_STEP + 1 batches)

model.train()
for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
    out = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
    )

    loss = criterion(out, labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if i % LOG_EVERY == 0:
        # Accuracy on the current training batch only.
        preds = out.argmax(dim=1)
        accuracy = (preds == labels).sum().item() / len(labels)
        print(i, loss.item(), accuracy)

    if i == MAX_STEP:
        break

0 0.7485172152519226 0.375
5 0.6585582494735718 0.625
10 0.6525106430053711 0.625
15 0.6303020119667053 0.75


KeyboardInterrupt: 