In [6]:
import torch
import torch.nn as nn
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd
from tqdm import tqdm
import os

def split_train_valid(csv_path,train_prec=0.75):
    csv_data = pd.read_csv(csv_path,sep='\t') 
    text_data = csv_data.text
    label_data = csv_data.label
    
    mask = np.random.rand(len(text_data))<train_prec
    train_text_data = text_data[mask].reset_index(drop=True)
    valid_text_data = text_data[~mask].reset_index(drop=True)
    train_label_data = label_data[mask].reset_index(drop=True)
    valid_label_data = label_data[~mask].reset_index(drop=True)
    
    return train_text_data,valid_text_data,train_label_data,valid_label_data

class MyDataset(Dataset):
    def __init__(self,text_data,label_data,word_num):
        self.text_data = text_data
        self.label_data = label_data
        self.word_num = word_num
        
    def __getitem__(self,index):
        #所有text内的token索引增加1，0空出来代表空格，将每个text控制在一定长度内
        text_str = self.text_data[index]
        text_list = [int(x)+1 for x in text_str.split()]
        if len(text_list)>=self.word_num:
            text_list = text_list[:self.word_num]
        else:
            text_list.extend([0]*(self.word_num-len(text_list)))
        text_array = np.array(text_list)
        label_array = np.array(self.label_data[index])
        return text_array,label_array
        
    def __len__(self):
        return len(self.text_data)

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(7551,101)
        self.rnn = nn.GRU(101,50,batch_first=True,bidirectional=True)
        self.fc = nn.Linear(100,14)
        
    def forward(self,X):
#         print('X:',X.shape)
        X = self.embedding(X)
#         print('embedding:',X.shape)
        _,X = self.rnn(X)
#         print('rnn:',X.shape)
        X = X.permute(1,0,2)
        X = X.reshape((X.shape[0],-1))
#         print('trans:',X.shape)
        y_hat = self.fc(X)
#         print('y_hat:',y_hat.shape)
        return y_hat

    
class MyTrain():
    def __init__(self,max_epoch=1,random_seed=1,lr=0.001,out_dir='./',word_num= 2674):
        self.max_epoch = max_epoch
        self.random_seed = random_seed
        self.lr = lr
        self.out_dir = out_dir
        self.iter = 0
        self.word_num = word_num
        
    def fix_random(self):
        import random
        import numpy as np
        import torch
        random.seed(self.random_seed)
        np.random.seed(self.random_seed)
        torch.random.manual_seed(self.random_seed)
        torch.cuda.random.manual_seed_all(self.random_seed)
        torch.backends.cudnn.deterministic = True
        print(f'random seed:{self.random_seed}')
        
    def my_train(self):
        
        max_epoch,lr = self.max_epoch,self.lr
        
        if self.random_seed is not None:
            self.fix_random()
        
        train_text_data,valid_text_data,train_label_data,valid_label_data = split_train_valid('./train_set.csv',train_prec=0.75)
        train_dataset = MyDataset(train_text_data,train_label_data,self.word_num)
        valid_dataset = MyDataset(valid_text_data,valid_label_data,self.word_num)
        
        my_model = MyModel()
        my_optim = torch.optim.Adam(my_model.parameters(),lr=lr)
        my_loss = nn.CrossEntropyLoss()
        
        if torch.cuda.is_available():
            my_model.cuda()
            my_loss.cuda()
        print(f'train device:{next(iter(my_model.parameters())).device}')  #显示训练设备
        
        best_f1_score = 0
        epoch_index = 0

        for epoch_index in range(max_epoch):
            
            loss_list = []
            train_f1_score_list = []
            valid_f1_score_list = []
            train_dataloader = DataLoader(train_dataset,batch_size=60,shuffle=True)
            valid_dataloader = DataLoader(valid_dataset,batch_size=20,shuffle=True)
            my_dataloader = tqdm(train_dataloader)
            
            batch_index = 0
                        
            for train_data,valid_data in zip(my_dataloader,valid_dataloader):
                                
                my_model.train()  #将模型设置为训练模式
                train_text,train_label = train_data
                valid_text,valid_label = valid_data
                
                if torch.cuda.is_available():
                    train_text = train_text.cuda()
                    train_label = train_label.cuda()
                    valid_text = valid_text.cuda()
                    valid_label = valid_label.cuda()
                
                train_y_hat = my_model(train_text)
                batch_train_loss = my_loss(train_y_hat,train_label)
                
                my_optim.zero_grad()
                batch_train_loss.backward()
                my_optim.step()
                my_optim.param_groups[0]['lr'] = lr*(0.8**(epoch_index%10))
        
                my_model.eval()  #将模型设置为验证模式
                with torch.no_grad():
                    valid_y_hat = my_model(valid_text)
                    batch_valid_f1_score = self.f1_score(valid_y_hat.data,valid_label.data)
                    valid_f1_score_list.append(batch_valid_f1_score)
                    mean_valid_f1 = round(sum(valid_f1_score_list)/len(valid_f1_score_list),3)

                    #显示batch结果
                    batch_lr = round(my_optim.param_groups[0]['lr'],5)
                    batch_loss = round(batch_train_loss.item(),4)
                    loss_list.append(batch_loss)
                    mean_loss = round((sum(loss_list)/len(loss_list)),3)

                    batch_train_f1_score = self.f1_score(train_y_hat.data,train_label.data)
                    train_f1_score_list.append(batch_train_f1_score)
                    mean_train_f1 = round(sum(train_f1_score_list)/len(train_f1_score_list),3)

                    my_dataloader.set_description(
                        f'epoch:{epoch_index},batch:{batch_index},lr:{batch_lr},loss:{batch_loss},mean_loss:{mean_loss},train_f1:{mean_train_f1},valid_f1:{mean_valid_f1}')
                    batch_index+=1
                    
            torch.save(my_model.state_dict(),os.path.join(
                self.out_dir,
                f'embedding_gru_best_{self.word_num}word_valid_f1_score_{round(mean_valid_f1,4)}'))
        
        return my_model
        
    def f1_score(self,y_hat,label,eps=1e-8):
        #y_hat(N,C),label(1)
        y_hat = y_hat.cpu()
        label = label.cpu()
        preds_list = list(torch.argmax(y_hat,dim=1).numpy())
        label_list = list(label.numpy())
#         print(f'preds:{preds_list},label:{label_list}')
        class_index_list = []
        for class_index in label_list:
            if class_index not in class_index_list:
                class_index_list.append(class_index)

        f1_score_list = []
        for index in class_index_list:
            if index not in preds_list:
                sub_f1_score = 0
            else:
                tp = 0
                fp = 0
                fn = 0
                for i in range(len(preds_list)):
                    if preds_list[i] == index and label_list[i] == index:
                        tp+=1
                    if preds_list[i] == index and label_list[i] != index: 
                        fp+=1
                    if preds_list[i] != index and label_list[i] == index:
                        fn+=1
                prec_val = tp/(tp+fp) 
                recall_val = tp/(tp+fn)
                sub_f1_score = 2*(prec_val*recall_val)/(prec_val+recall_val+eps)
            f1_score_list.append(sub_f1_score)

        batch_f1_score = sum(f1_score_list)/len(f1_score_list)

        return batch_f1_score

In [7]:
if __name__ == '__main__':
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    rnn_train = MyTrain(max_epoch=10,random_seed=1,word_num=2674)
    my_model = rnn_train.my_train()

random seed:1
train device:cuda:0


epoch:0,batch:2496,lr:0.001,loss:0.2769,mean_loss:0.729,train_f1:0.635,valid_f1:0.685: 100%|▉| 2497/2502 [04:49<00:00, 
epoch:1,batch:2496,lr:0.0008,loss:0.2035,mean_loss:0.249,train_f1:0.889,valid_f1:0.898: 100%|▉| 2497/2502 [04:46<00:00,
epoch:2,batch:2496,lr:0.00064,loss:0.3138,mean_loss:0.184,train_f1:0.922,valid_f1:0.918: 100%|▉| 2497/2502 [04:46<00:00
epoch:3,batch:2496,lr:0.00051,loss:0.065,mean_loss:0.149,train_f1:0.937,valid_f1:0.922: 100%|▉| 2497/2502 [04:44<00:00,
epoch:4,batch:2496,lr:0.00041,loss:0.0498,mean_loss:0.126,train_f1:0.947,valid_f1:0.927: 100%|▉| 2497/2502 [04:45<00:00
epoch:5,batch:2496,lr:0.00033,loss:0.0228,mean_loss:0.108,train_f1:0.956,valid_f1:0.929: 100%|▉| 2497/2502 [04:46<00:00
epoch:6,batch:2496,lr:0.00026,loss:0.1217,mean_loss:0.095,train_f1:0.963,valid_f1:0.928: 100%|▉| 2497/2502 [04:40<00:00
epoch:7,batch:2496,lr:0.00021,loss:0.0343,mean_loss:0.084,train_f1:0.967,valid_f1:0.928: 100%|▉| 2497/2502 [04:44<00:00
epoch:8,batch:2496,lr:0.00017,loss:0.103