In [1]:
from sklearn.model_selection import train_test_split
import sys
sys.path.append("D:/Experiment")
from MyKu import MyXLM_Base
from MyKu import processing
import torch
from sklearn import metrics
import time
import os
from tqdm import tqdm
from torch.optim import Adam
from torch import nn
import numpy as np
import pandas as pd

In [2]:
data = processing.load_swsr()

In [3]:
train, test = train_test_split(data, test_size=0.2, random_state=42)

In [4]:
# 训练准备阶段，设置超参数和全局变量

batch_size = 8
num_epoch = 10  # 训练轮次
check_step = 1  # 用以训练中途对模型进行检验：每check_step个epoch进行一次测试和保存模型

learning_rate = 1e-5  # 优化器的学习率

# 获取训练、测试数据、分类类别总数
en_train_data, en_test_data, de_train_data, de_test_data, hi_train_data, hi_test_data = processing.load_hasoc2020()
categories = 2

train_iter, test_iter = MyXLM_Base.load_xlm_data(
    train, test, batch_size)
#固定写法，可以牢记，cuda代表Gpu
# torch.cuda.is_available()可以查看当前Gpu是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 加载预训练模型，因为这里是英文数据集，需要用在英文上的预训练模型：bert-base-uncased
# uncased指该预训练模型对应的词表不区分字母的大小写
# 详情可了解：https://huggingface.co/bert-base-uncased
pretrained_model_name = 'xlm-roberta-base'
# 创建模型 BertSST2Model
model = MyXLM_Base.MyXlmModel(categories, pretrained_model_name)
# 固定写法，将模型加载到device上，
# 如果是GPU上运行，此时可以观察到GPU的显存增加
model.to(device)

# 训练过程
# Adam是最近较为常用的优化器，详情可查看：https://www.jianshu.com/p/aebcaf8af76e
optimizer = Adam(model.parameters(), learning_rate)  # 使用Adam优化器
loss = nn.CrossEntropyLoss()  # 使用crossentropy作为二分类任务的损失函数

# 记录当前训练时间，用以记录日志和存储
timestamp = time.strftime("%m_%d_%H_%M", time.localtime())


In [5]:
# TaskA

file_name = 'readme.md'
model_save_path = 'D:/Experiment_models_save/twitter-xlm-roberta-base-sentiment/hasoc2020/'
name = 'taskA.pth'
en_temp_best = 0.99
de_temp_best = 0.99
hi_temp_best = 0.73
for epoch in range(1, num_epoch + 1):
    MyXLM_Base.train(model, train_iter, device, optimizer, loss, epoch)
    # MyXLM_Base.train(model, de_train_iter, device, optimizer, loss, epoch)
    # MyXLM_Base.train(model, hi_train_iter, device, optimizer, loss, epoch)
    en_acc_score = MyXLM_Base.test(
        model, test_iter, device, epoch, file_name)
    # de_acc_score = MyXLM_Base.test(model, de_test_iter, device, epoch, file_name)
    # hi_acc_score = MyXLM_Base.test(model, hi_test_iter, device, epoch, file_name)
    print('\n\n')
    # if en_acc_score > en_temp_best and de_acc_score > de_temp_best and hi_acc_score > hi_temp_best:
    #     en_temp_best, de_temp_best, hi_temp_best = en_acc_score, de_acc_score, hi_acc_score
    #     MyXLM_Base.save_pretrained(model, model_save_path, name)
    #     print(f'best en_acc_socre : {en_acc_score}, best de_acc_score : {de_acc_score}, best hi_acc_score : {hi_acc_score}')
print(
    f'best en_acc_socre : {en_temp_best}, best de_acc_score : {de_temp_best}, best hi_acc_score : {hi_temp_best}')


Training Epoch 1: 100%|[31m██████████[0m| 897/897 [04:14<00:00,  3.52it/s]
Testing: 100%|[32m██████████[0m| 1794/1794 [00:20<00:00, 88.92it/s]


[[1019  147]
 [ 214  414]]
              precision    recall  f1-score   support

           0       0.83      0.87      0.85      1166
           1       0.74      0.66      0.70       628

    accuracy                           0.80      1794
   macro avg       0.78      0.77      0.77      1794
weighted avg       0.80      0.80      0.80      1794

Acc : 0.7987736900780379	 F1: 0.7729520745783128





Training Epoch 2: 100%|[31m██████████[0m| 897/897 [04:13<00:00,  3.54it/s]
Testing: 100%|[32m██████████[0m| 1794/1794 [00:20<00:00, 88.54it/s] 


[[960 206]
 [123 505]]
              precision    recall  f1-score   support

           0       0.89      0.82      0.85      1166
           1       0.71      0.80      0.75       628

    accuracy                           0.82      1794
   macro avg       0.80      0.81      0.80      1794
weighted avg       0.82      0.82      0.82      1794

Acc : 0.8166109253065775	 F1: 0.8040035053335461





Training Epoch 3: 100%|[31m██████████[0m| 897/897 [04:15<00:00,  3.51it/s]
Testing: 100%|[32m██████████[0m| 1794/1794 [00:20<00:00, 87.88it/s]


[[1051  115]
 [ 238  390]]
              precision    recall  f1-score   support

           0       0.82      0.90      0.86      1166
           1       0.77      0.62      0.69       628

    accuracy                           0.80      1794
   macro avg       0.79      0.76      0.77      1794
weighted avg       0.80      0.80      0.80      1794

Acc : 0.8032329988851727	 F1: 0.7723247942218538





Training Epoch 4: 100%|[31m██████████[0m| 897/897 [04:18<00:00,  3.47it/s]
Testing: 100%|[32m██████████[0m| 1794/1794 [00:20<00:00, 87.79it/s]


[[1017  149]
 [ 195  433]]
              precision    recall  f1-score   support

           0       0.84      0.87      0.86      1166
           1       0.74      0.69      0.72       628

    accuracy                           0.81      1794
   macro avg       0.79      0.78      0.79      1794
weighted avg       0.81      0.81      0.81      1794

Acc : 0.8082497212931996	 F1: 0.785521550855292





Training Epoch 5: 100%|[31m██████████[0m| 897/897 [04:27<00:00,  3.35it/s]
Testing: 100%|[32m██████████[0m| 1794/1794 [00:20<00:00, 87.03it/s]


[[1003  163]
 [ 193  435]]
              precision    recall  f1-score   support

           0       0.84      0.86      0.85      1166
           1       0.73      0.69      0.71       628

    accuracy                           0.80      1794
   macro avg       0.78      0.78      0.78      1794
weighted avg       0.80      0.80      0.80      1794

Acc : 0.8015607580824972	 F1: 0.7794525335208224





Training Epoch 6: 100%|[31m██████████[0m| 897/897 [04:27<00:00,  3.35it/s]
Testing: 100%|[32m██████████[0m| 1794/1794 [00:20<00:00, 86.79it/s]


[[926 240]
 [124 504]]
              precision    recall  f1-score   support

           0       0.88      0.79      0.84      1166
           1       0.68      0.80      0.73       628

    accuracy                           0.80      1794
   macro avg       0.78      0.80      0.79      1794
weighted avg       0.81      0.80      0.80      1794

Acc : 0.7971014492753623	 F1: 0.7852169748765931





Training Epoch 7: 100%|[31m██████████[0m| 897/897 [04:27<00:00,  3.36it/s]
Testing: 100%|[32m██████████[0m| 1794/1794 [00:20<00:00, 87.75it/s] 


[[912 254]
 [137 491]]
              precision    recall  f1-score   support

           0       0.87      0.78      0.82      1166
           1       0.66      0.78      0.72       628

    accuracy                           0.78      1794
   macro avg       0.76      0.78      0.77      1794
weighted avg       0.80      0.78      0.79      1794

Acc : 0.782051282051282	 F1: 0.7693492196324142





Training Epoch 8: 100%|[31m██████████[0m| 897/897 [04:27<00:00,  3.36it/s]
Testing: 100%|[32m██████████[0m| 1794/1794 [00:27<00:00, 65.24it/s]


[[935 231]
 [147 481]]
              precision    recall  f1-score   support

           0       0.86      0.80      0.83      1166
           1       0.68      0.77      0.72       628

    accuracy                           0.79      1794
   macro avg       0.77      0.78      0.77      1794
weighted avg       0.80      0.79      0.79      1794

Acc : 0.7892976588628763	 F1: 0.7748804907845117





Training Epoch 9: 100%|[31m██████████[0m| 897/897 [05:17<00:00,  2.82it/s]
Testing: 100%|[32m██████████[0m| 1794/1794 [00:22<00:00, 80.74it/s]


[[1084   82]
 [ 325  303]]
              precision    recall  f1-score   support

           0       0.77      0.93      0.84      1166
           1       0.79      0.48      0.60       628

    accuracy                           0.77      1794
   macro avg       0.78      0.71      0.72      1794
weighted avg       0.78      0.77      0.76      1794

Acc : 0.7731326644370122	 F1: 0.7200824236383327





Training Epoch 10: 100%|[31m██████████[0m| 897/897 [05:18<00:00,  2.82it/s]
Testing: 100%|[32m██████████[0m| 1794/1794 [00:22<00:00, 80.51it/s]

[[881 285]
 [131 497]]
              precision    recall  f1-score   support

           0       0.87      0.76      0.81      1166
           1       0.64      0.79      0.70       628

    accuracy                           0.77      1794
   macro avg       0.75      0.77      0.76      1794
weighted avg       0.79      0.77      0.77      1794

Acc : 0.7681159420289855	 F1: 0.7569818103667232



best en_acc_socre : 0.99, best de_acc_score : 0.99, best hi_acc_score : 0.73



