In [2]:
# 1 is positive sentiment, 0 is negative sentiment
import pandas as pd
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from transformers import Trainer,TrainingArguments, BertTokenizer, BertModel, BertPreTrainedModel,BertConfig
from torch.utils.data import Dataset, DataLoader
from torch import nn
import warnings
warnings.filterwarnings('ignore')
import sys
sys.setrecursionlimit(3000)


def read_data(data_dir):
    """Load a CSV file and add a combined ``text`` column.

    ``text`` is the concatenation, in order, of the ``content``,
    ``level_1``, ``level_2``, ``level_3`` and ``level_4`` columns.

    Args:
        data_dir: Location of the CSV file; passed unchanged to
            ``pandas.read_csv``.

    Returns:
        The loaded DataFrame, with NaN in ``content`` replaced by ``''``
        and the new ``text`` column appended. The ``level_*`` columns
        themselves are left as read.
    """
    data = pd.read_csv(data_dir)
    data['content'] = data['content'].fillna('')
    # A single NaN piece would turn the whole concatenated value into NaN
    # (later stringified as 'nan'), so treat missing pieces as empty strings.
    text = data['content']
    for column in ('level_1', 'level_2', 'level_3', 'level_4'):
        text = text + data[column].fillna('')
    data['text'] = text
    return data

def fill_paddings(data, maxlen):
    """Return ``data`` as a tensor, zero-padded or cut down to ``maxlen``.

    Args:
        data: List of numbers (e.g. token ids or mask values).
        maxlen: Target length.

    Returns:
        A ``torch.Tensor`` built from ``data`` followed by zeros when
        ``data`` is shorter than ``maxlen``, otherwise from
        ``data[:maxlen]``.
    """
    missing = maxlen - len(data)
    if missing > 0:
        return torch.tensor(data + [0] * missing)
    return torch.tensor(data[:maxlen])

class InputDataSet(Dataset):
    """Map-style dataset that turns rows of a DataFrame into BERT inputs.

    Every item is built from the ``text`` and ``label`` columns of the
    wrapped data and is padded/truncated to a fixed ``max_len``.
    """

    def __init__(self, data, tokenizer, max_len):
        """Store the data source and tokenization settings.

        Args:
            data: pandas DataFrame with ``text`` and ``label`` columns.
            tokenizer: Tokenizer exposing ``tokenize``,
                ``convert_tokens_to_ids``, ``cls_token_id`` and
                ``sep_token_id`` (e.g. ``BertTokenizer``).
            max_len: Length of every returned id/mask sequence.
        """
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of rows in the wrapped data."""
        return len(self.data)

    def __getitem__(self, item):
        """Build the model inputs for the row at position ``item``.

        Args:
            item: Zero-based row position.

        Returns:
            A dict with the raw ``text``, the ``input_ids``,
            ``attention_mask`` and ``token_type_ids`` tensors (each of
            length ``max_len``) and the ``labels`` tensor (dtype long).
        """
        # Positional lookup: label-based ``[item]`` breaks as soon as the
        # DataFrame index is not 0..n-1 (e.g. after a shuffle/split).
        text = str(self.data['text'].iloc[item])
        labels = torch.tensor(self.data['label'].iloc[item], dtype=torch.long)

        # Build the ids by hand. Reserve two slots for the special tokens so
        # that a long text keeps its closing [SEP] instead of having it
        # sliced off by the final truncation.
        tokens = self.tokenizer.tokenize(text)
        tokens = tokens[:max(self.max_len - 2, 0)]
        tokens_ids = (
            [self.tokenizer.cls_token_id]
            + self.tokenizer.convert_tokens_to_ids(tokens)
            + [self.tokenizer.sep_token_id]
        )
        real_len = len(tokens_ids)

        input_ids = fill_paddings(tokens_ids, self.max_len)
        # 1 for real tokens, 0 for the padding added by fill_paddings.
        attention_mask = fill_paddings([1] * real_len, self.max_len)
        # Single-segment input: every position belongs to segment 0.
        token_type_ids = fill_paddings([0] * real_len, self.max_len)

        return {
            'text': text,
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'token_type_ids': token_type_ids,
            'labels': labels,
        }


if __name__ == '__main__':
    # Smoke test: load the CSV files, build the training dataset and print
    # the first batch together with the shape of each tensor field.
    model_dir = 'bert-base-chinese'
    train_dir = 'data/train.csv'
    dev_dir = 'data/dev.csv'

    train = read_data(train_dir)
    test = read_data(dev_dir)

    tokenizer = BertTokenizer.from_pretrained(model_dir)
    train_dataset = InputDataSet(train, tokenizer=tokenizer, max_len=128)
    train_dataloader = DataLoader(train_dataset, batch_size=4)
    batch = next(iter(train_dataloader))

    print(batch)
    for field in ('input_ids', 'attention_mask', 'token_type_ids', 'labels'):
        print(batch[field].shape)

{'text': ['使用移动手动电动工具,外接线绝缘皮破损,应停止使用.工业/危化品类（现场）—2016版（二）电气安全6、移动用电产品、电动工具及照明1、移动使用的用电产品和I类电动工具的绝缘线，必须采用三芯(单相)或四芯(三相)多股铜芯橡套软线。', '一般工业/危化品类（现场）—2016版（一）消防检查1、防火巡查3、消防设施、器材和消防安全标志是否在位、完整；', '消防知识要加强工业/危化品类（现场）—2016版（一）消防检查2、防火检查6、重点工种人员以及其他员工消防知识的掌握情况；', '消防通道有货物摆放 清理不及时工业/危化品类（现场）—2016版（一）消防检查1、防火巡查3、消防设施、器材和消防安全标志是否在位、完整；'], 'input_ids': tensor([[ 101,  886, 4500, 4919, 1220, 2797, 1220, 4510, 1220, 2339, 1072,  117,
         1912, 2970, 5296, 5318, 5357, 4649, 4788, 2938,  117, 2418,  977, 3632,
          886, 4500,  119, 2339,  689,  120, 1314, 1265, 1501, 5102, 8020, 4385,
         1767, 8021,  100, 8112, 4276, 8020,  753, 8021, 4510, 3698, 2128, 1059,
          127,  510, 4919, 1220, 4500, 4510,  772, 1501,  510, 4510, 1220, 2339,
         1072, 1350, 4212, 3209,  122,  510, 4919, 1220,  886, 4500, 4638, 4500,
         4510,  772, 1501, 1469,  151, 5102, 4510, 1220, 2339, 1072, 4638, 5318,
         5357, 5296, 8024, 2553, 7557, 7023, 4500,  676, 5708,  113, 1296, 4685,
          1