In [3]:
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel
from tqdm import tqdm
import os
import time
from transformers import BertTokenizer
from transformers import logging
import processing
from sklearn import metrics
import MyBERT

In [4]:

# ---------------------------------------------------------------------------
# Training setup: hyper-parameters and notebook-wide globals
# ---------------------------------------------------------------------------
file_name = 'readme.md'   # path handed to MyBERT.test in the training cell
batch_size = 16
num_epoch = 10            # number of training epochs
check_step = 1            # intended interval (in epochs) between mid-training test/save passes

learning_rate = 1e-5      # learning rate for the optimizer

# Fetch the train / test splits and set the number of target classes.
# NOTE(review): `len` looks like a sample-count argument of the project helper
# -- confirm against processing.get_exist2021_data_temp.
train_data = processing.get_exist2021_data_temp(type='train', len=3437)
test_data = processing.get_exist2021_data_temp(type='test', len=1000)
categories = 2

# Wrap both splits into batch iterators.
train_iter, test_iter = MyBERT.load_bert_data(train_data, test_data, batch_size)

# Standard device selection: use the GPU ("cuda") when
# torch.cuda.is_available() reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Pre-trained checkpoint. The original author notes the dataset is English, so
# an English checkpoint is used; "uncased" means its vocabulary does not
# distinguish upper- from lower-case letters.
# Details: https://huggingface.co/bert-base-uncased
pretrained_model_name = 'bert-base-uncased'

# Build the classifier (MyBERT.MyBertModel) and move it onto `device`.
# When running on a GPU, its memory usage can be seen to rise at this point.
model = MyBERT.MyBertModel(categories, pretrained_model_name)
model.to(device)

# Optimisation: Adam optimizer plus cross-entropy loss for the two-class task.
# Background on Adam: https://www.jianshu.com/p/aebcaf8af76e
optimizer = Adam(model.parameters(), lr=learning_rate)
loss = nn.CrossEntropyLoss()

# Start-of-run timestamp (month_day_hour_minute, local time); the original
# author intends it for logging and for naming stored artefacts.
timestamp = time.strftime("%m_%d_%H_%M")


In [5]:
# Main training loop.
#
# The log file is opened in append mode before training starts (which creates
# it if it does not exist yet). It is managed by a `with` block so the handle
# is closed when the loop finishes *or* when the run is interrupted
# (e.g. KeyboardInterrupt), instead of being left open for the kernel's lifetime.
# `fp` itself is not written to here; MyBERT.test is given `file_name` and
# presumably does its own writing -- verify against MyBERT.test.
with open(file_name, 'a+') as fp:
    # Guard against a zero/negative interval so the modulo below cannot fail.
    eval_every = max(1, check_step)
    for epoch in range(1, num_epoch + 1):
        MyBERT.train(model, train_iter, device, optimizer, loss, epoch)
        # Evaluate only every `check_step` epochs, as that setting is
        # documented; with check_step = 1 this runs after every epoch.
        if epoch % eval_every == 0:
            MyBERT.test(model, test_iter, device, epoch, file_name)

Training Epoch 1: 100%|[31m██████████[0m| 215/215 [00:48<00:00,  4.41it/s]
Testing: 100%|[32m██████████[0m| 999/999 [00:11<00:00, 84.35it/s]


[[289 168]
 [ 97 445]]
              precision    recall  f1-score   support

           0       0.75      0.63      0.69       457
           1       0.73      0.82      0.77       542

    accuracy                           0.73       999
   macro avg       0.74      0.73      0.73       999
weighted avg       0.74      0.73      0.73       999

Acc : 0.7347347347347347	 F1: 0.7705627705627706


Training Epoch 2: 100%|[31m██████████[0m| 215/215 [00:47<00:00,  4.51it/s]
Testing: 100%|[32m██████████[0m| 999/999 [00:11<00:00, 85.34it/s]


[[340 117]
 [147 395]]
              precision    recall  f1-score   support

           0       0.70      0.74      0.72       457
           1       0.77      0.73      0.75       542

    accuracy                           0.74       999
   macro avg       0.73      0.74      0.73       999
weighted avg       0.74      0.74      0.74       999

Acc : 0.7357357357357357	 F1: 0.7495256166982922


Training Epoch 3:  33%|[31m███▎      [0m| 71/215 [00:15<00:32,  4.49it/s]


KeyboardInterrupt: 