# 导入相关库

In [1]:
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader, TensorDataset
import numpy as np
import pandas as pd
import random
import re

from sklearn import preprocessing

# 指定使用的gpu
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '6'

# 数据处理

In [14]:
# 读取数据
all_data = pd.read_csv("datasets/text_classification/图书馆、情报与档案管理_体育学_法学.csv")

# 只使用摘要和学科
all_data = all_data[['学科','摘要']]
all_data.columns = ['label','text']

# 对学科进行label encoder
le = preprocessing.LabelEncoder()
le.fit(all_data['label'])
label2id = dict(zip(le.classes_, range(0, len(le.classes_))))
id2label = dict(map(reversed, label2id.items()))
print(label2id)
print(id2label)
all_data['label_id'] = all_data['label'].map(label2id)

print(all_data.info())
all_data.head()

{'体育学': 0, '图书馆、情报与档案管理': 1, '法学': 2}
{0: '体育学', 1: '图书馆、情报与档案管理', 2: '法学'}
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 14386 entries, 0 to 14385
Data columns (total 3 columns):
 #   Column    Non-Null Count  Dtype 
---  ------    --------------  ----- 
 0   label     14386 non-null  object
 1   text      14386 non-null  object
 2   label_id  14386 non-null  int64 
dtypes: int64(1), object(2)
memory usage: 337.3+ KB
None


Unnamed: 0,label,text,label_id
0,法学,"域外简易程序在适用范围上主要适用于罪行较轻的犯罪,在程序类型上普遍呈现出多样化的特点,在具体...",2
1,法学,"政府非税收入是中国财政收入的重要组成部分,然而其理论研究严重落后于现实需要。文章提出应创建""...",2
2,法学,集团仲裁是美国为消费者、雇佣、金融等格式合同的弱方当事人集合维权创设的一种纠纷解决方式。美国...,2
3,法学,"长期以来,以政府为主导的犯罪打击模式在我国犯罪治理中居于主导地位,但这种模式忽视了被害人等其...",2
4,法学,我国的有限责任公司盈余分配纠纷是有限责任公司股权纠纷的典型争议之一。经济欠发达地区的法院明显...,2


In [70]:
# 划分为训练集和验证集
x_train, x_test, train_label, test_label =  train_test_split(all_data['text'], 
                                                             all_data['label_id'], 
                                                             test_size=0.2, 
                                                             random_state = 123)


In [72]:
# 重新索引
x_train, x_test, train_label, test_label = x_train.reset_index(drop=True), x_test.reset_index(drop=True), train_label.reset_index(drop=True), test_label.reset_index(drop=True)

train_label

0        2
1        0
2        0
3        2
4        0
        ..
11503    0
11504    1
11505    2
11506    1
11507    2
Name: label_id, Length: 11508, dtype: int64

# 模型训练

## tokenizer

In [74]:
# pip install transformers
# transformers bert相关的模型使用和加载
from transformers import BertTokenizer
# 分词器，词典

# 第一次使用会下载bert-base-chinese到本地缓存，
# 或者自己到huggingface官网下载模型，这里填写模型路径
# 不想翻墙的话到https://github.com/ymcui/Chinese-BERT-wwm下载哈工大开源的版本
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
train_encoding = tokenizer(list(x_train), truncation=True, padding=True, max_length=128)
test_encoding = tokenizer(list(x_test), truncation=True, padding=True, max_length=128)

In [75]:
# input_ids：字的编码
# token_type_ids：标识是第一个句子还是第二个句子
# attention_mask：标识是不是填充, 是不是可以被attention到
encode_example = tokenizer(x_train[0])
decode_example = tokenizer.decode(encode_example.input_ids)

print(encode_example, decode_example)

{'input_ids': [101, 7270, 3309, 809, 3341, 8024, 5401, 1744, 2792, 3354, 2456, 4638, 809, 6887, 2548, 7344, 5576, 510, 3791, 2526, 3780, 5576, 510, 4664, 4719, 2829, 5576, 711, 4294, 2519, 4638, 5468, 6930, 3124, 2424, 4989, 860, 1353, 5576, 6571, 3322, 1169, 2190, 3300, 3126, 7564, 7344, 1469, 6883, 1169, 5576, 6571, 6629, 1168, 749, 1068, 7241, 868, 4500, 511, 2968, 4955, 5401, 1744, 5468, 6930, 3124, 2424, 1741, 5312, 794, 3124, 6887, 2548, 1469, 6121, 3124, 840, 4415, 2456, 6392, 2792, 6822, 6121, 4638, 4989, 3791, 1469, 2809, 3791, 2141, 6664, 8024, 5440, 2175, 1071, 6887, 2548, 3780, 5576, 6817, 6121, 3322, 1169, 8024, 2190, 2769, 1744, 1217, 2571, 3354, 2456, 5143, 5320, 510, 6226, 5745, 510, 3403, 1114, 4638, 6121, 3124, 840, 4415, 1169, 2428, 1265, 860, 5143, 510, 2130, 1587, 6887, 2548, 5276, 3338, 680, 3791, 3780, 2674, 2770, 4685, 5310, 1394, 4638, 1353, 5576, 3322, 1169, 1072, 3300, 671, 2137, 4638, 955, 7063, 2692, 721, 511, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0

In [76]:
# 数据集读取
class NewsDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels
    
    # 读取单个样本
    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(int(self.labels[idx]))
        return item
    
    def __len__(self):
        return len(self.labels)

train_dataset = NewsDataset(train_encoding, train_label)
test_dataset = NewsDataset(test_encoding, test_label)

In [79]:
train_dataset[0]

{'input_ids': tensor([ 101, 7270, 3309,  809, 3341, 8024, 5401, 1744, 2792, 3354, 2456, 4638,
          809, 6887, 2548, 7344, 5576,  510, 3791, 2526, 3780, 5576,  510, 4664,
         4719, 2829, 5576,  711, 4294, 2519, 4638, 5468, 6930, 3124, 2424, 4989,
          860, 1353, 5576, 6571, 3322, 1169, 2190, 3300, 3126, 7564, 7344, 1469,
         6883, 1169, 5576, 6571, 6629, 1168,  749, 1068, 7241,  868, 4500,  511,
         2968, 4955, 5401, 1744, 5468, 6930, 3124, 2424, 1741, 5312,  794, 3124,
         6887, 2548, 1469, 6121, 3124,  840, 4415, 2456, 6392, 2792, 6822, 6121,
         4638, 4989, 3791, 1469, 2809, 3791, 2141, 6664, 8024, 5440, 2175, 1071,
         6887, 2548, 3780, 5576, 6817, 6121, 3322, 1169, 8024, 2190, 2769, 1744,
         1217, 2571, 3354, 2456, 5143, 5320,  510, 6226, 5745,  510, 3403, 1114,
         4638, 6121, 3124,  840, 4415, 1169, 2428,  102]),
 'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0

In [81]:
# 精度计算
def flat_accuracy(preds, labels):
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)

## 加载预训练模型进行训练

In [82]:
from transformers import BertForSequenceClassification, AdamW, get_linear_schedule_with_warmup

model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=3)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# 单个读取到批量读取
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True)

# 优化方法
optim = AdamW(model.parameters(), lr=2e-5)
total_steps = len(train_loader) * 1
scheduler = get_linear_schedule_with_warmup(optim, 
                                            num_warmup_steps = 0, # Default value in run_glue.py
                                            num_training_steps = total_steps)

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [83]:
# 训练函数
def train():
    model.train()
    total_train_loss = 0
    iter_num = 0
    total_iter = len(train_loader)
    for batch in train_loader:
        # 正向传播
        optim.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs[0]
        total_train_loss += loss.item()
        
        # 反向梯度信息
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        
        # 参数更新
        optim.step()
        scheduler.step()

        iter_num += 1
        if(iter_num % 100==0):
            print("epoth: %d, iter_num: %d, loss: %.4f, %.2f%%" % (epoch, iter_num, loss.item(), iter_num/total_iter*100))
        
    print("Epoch: %d, Average training loss: %.4f"%(epoch, total_train_loss/len(train_loader)))
    
def validation():
    model.eval()
    total_eval_accuracy = 0
    total_eval_loss = 0
    for batch in test_dataloader:
        with torch.no_grad():
            # 正常传播
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        
        loss = outputs[0]
        logits = outputs[1]

        total_eval_loss += loss.item()
        logits = logits.detach().cpu().numpy()
        label_ids = labels.to('cpu').numpy()
        total_eval_accuracy += flat_accuracy(logits, label_ids)
        
    avg_val_accuracy = total_eval_accuracy / len(test_dataloader)
    print("Accuracy: %.4f" % (avg_val_accuracy))
    print("Average testing loss: %.4f"%(total_eval_loss/len(test_dataloader)))
    print("-------------------------------")
    

# 训练4个epochs
for epoch in range(4):
    print("------------Epoch: %d ----------------" % epoch)
    train()
    validation()

------------Epoch: 0 ----------------
epoth: 0, iter_num: 100, loss: 0.0134, 13.89%
epoth: 0, iter_num: 200, loss: 0.0046, 27.78%
epoth: 0, iter_num: 300, loss: 0.0037, 41.67%
epoth: 0, iter_num: 400, loss: 0.0022, 55.56%
epoth: 0, iter_num: 500, loss: 0.0020, 69.44%
epoth: 0, iter_num: 600, loss: 0.0045, 83.33%
epoth: 0, iter_num: 700, loss: 0.0022, 97.22%
Epoch: 0, Average training loss: 0.1065
Accuracy: 0.9854
Average testing loss: 0.0688
-------------------------------
------------Epoch: 1 ----------------
epoth: 1, iter_num: 100, loss: 0.3073, 13.89%
epoth: 1, iter_num: 200, loss: 0.0021, 27.78%
epoth: 1, iter_num: 300, loss: 0.0017, 41.67%
epoth: 1, iter_num: 400, loss: 0.0043, 55.56%
epoth: 1, iter_num: 500, loss: 0.0036, 69.44%
epoth: 1, iter_num: 600, loss: 0.0024, 83.33%
epoth: 1, iter_num: 700, loss: 0.0033, 97.22%
Epoch: 1, Average training loss: 0.0381
Accuracy: 0.9854
Average testing loss: 0.0688
-------------------------------
------------Epoch: 2 ----------------
epoth:

## 模型保存

In [None]:
# 1、只保存模型参数
# 保存
torch.save(model.state_dict(), '\parameter.pkl')
# 加载
model = TheModelClass(...)
model.load_state_dict(torch.load('\parameter.pkl'))


In [84]:
# 2、保存完整模型
# 保存
torch.save(model, '\text_classification_model.pkl')
# 加载
# model = torch.load('\model.pkl')

# 模型预测

In [100]:
predict_example = ['大数据时代的小数据因其个体化特征优势在信息资源个性化服务过程中发挥着无可比拟的优势',
                   '高校校规是高等学校公共权力行使的重要依据,是现代大学治理的介质文本,在法律地位上可类同于"规章以下的规范性文件";',
                   '为深入发掘足球运动多元社会功能,争取更多社会共识支持,服务社会发展进步,助力中央足球发展改革战略,运用文献资料、逻辑分析等研究方法,根据综合国力构成要素,从政治、经济、军事、教育、文化、科技6个角度分析足球运动发展倒逼中国社会全面改革的价值意义,认为:1)足球运动可以代替战争这一人类社会竞争最高形式,为社会发展提供变革动力,通过变革促进社会资源整合利用效率提高,在建设、传播国家文化软实力,扩大对外友好交流同时,培养奋发向上的民族意志和凝聚力,服务综合国力整体提升;2)从成功所需因素角度,在相同地缘、社会文化和外交联盟背景下,奥运象征核战力,足球象征常规战力,奥运和足球两种"和平年代的战争",对社会发展具有不同促进作用,奥运激励精英先富,足球要求全民共富。美国常规战力要由其美式足球运动和女子足球所获6个世界冠军体现。'
                  ]
predict_example = tokenizer(predict_example, padding=True, max_length=128)
# print(predict_example)
with torch.no_grad():
        # 正常传播
        input_ids = torch.tensor(predict_example['input_ids']).to(device)
        attention_mask =  torch.tensor(predict_example['attention_mask']).to(device)
#         labels = predict_example['labels'].to(device)
        outputs = model(input_ids, attention_mask=attention_mask)
        print(outputs)
        result = outputs['logits'].argmax(axis=1)
        print(result)

print([id2label[int(i)] for i in list(result)])


SequenceClassifierOutput(loss=None, logits=tensor([[-3.1282,  3.8290, -1.9724],
        [-2.0875, -1.7345,  5.3452],
        [ 5.5092, -2.2844, -1.1295]], device='cuda:0'), hidden_states=None, attentions=None)
tensor([1, 2, 0], device='cuda:0')
['图书馆、情报与档案管理', '法学', '体育学']
