In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from torch.nn.utils import clip_grad_norm_
import numpy as np
import pandas as pd
import jieba
from tqdm import tqdm
from torchtext import data, datasets
from torchsummary import summary
from gensim.models import KeyedVectors

jieba.load_userdict('../word2vec/500000-dict.txt')

Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 0.511 seconds.
Prefix dict has been built succesfully.


In [2]:
'''
词典类，记录数据集中数显的词汇及其索引和数量
'''
class Dictionary:
    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.word2count = {}
        
    def add_word(self, word):
        # build word - count
        if word not in self.word2count:
            self.word2count[word] = 1
        else:
            self.word2count[word] += 1
    
    '''
    仅使用训练集中出现次数 >= 阈值的词汇构建词表
    '''
    def cut_dict(self, cut=2):
        self.word2idx = {'<PAD>':0, '<UNK>':1}
        self.idx2word = {0:'<PAD>', 1:'<UNK>'}
        word2count = {}
        idx = 2
        
        for word in self.word2count.keys():
            if self.word2count[word] >= cut:
                self.word2idx[word] = idx
                self.idx2word[idx] = word
                idx += 1
                word2count[word] = self.word2count[word]
                
        self.word2count = word2count
        

'''
构建停止词表
'''
stop_words = []
with open('../word2vec/stopwords.txt') as f:
    for line in f:
        stop_words.append(line.strip())

In [3]:
'''
对评论进行中文分词
'''
def seg(data_frame):
    segs = jieba.cut(data_frame['review'], cut_all=False)
    l = []
    for seg in segs:
        if seg not in stop_words and seg != ' ':
            l.append(seg)
    return l
    
df = pd.read_csv('../datasets/online_shopping_10_cats.csv')
df = df[['label', 'review']].dropna().copy()  # 删除具有缺省的行

df['review'] = df.apply(seg, axis=1)

'''
简单的样本数据统计
'''
length_list = [len(review) for review in df['review']]
print('数据集样本最小长度: {}\n'
      '数据集样本最大长度: {}\n'
      '数据集样本平均长度: {:.1f}'
      .format(min(length_list), max(length_list), sum(length_list)/len(length_list)))

print('正样本数量: {}\n'
      '负样本数量: {}'
      .format(len(df[df['label']==1]), len(df[df['label']==0])))

数据集样本最小长度: 1
数据集样本最大长度: 1275
数据集样本平均长度: 26.0
正样本数量: 31727
负样本数量: 31046


In [4]:
'''
加载词向量
'''
vector = KeyedVectors.load_word2vec_format('../word2vec/500000-small.txt')

  'See the migration notes for details: %s' % _MIGRATION_NOTES_URL


In [5]:
'''
获取数据集字典
'''

dictionary = Dictionary()

for review in df['review']:
    for i, word in enumerate(review):
        dictionary.add_word(word)
        
# 使用出现次数至少为 2 的词构建词典
dictionary.cut_dict(2)

In [6]:
'''
对训数据中的词汇进行统计
'''
em_word = 0
for word in dictionary.word2idx.keys():
    if word in vector:
        em_word += 1

print('词汇总量为: {}'
      '\n具有词向量的词汇量为: {}'
      '\n占比为: {:.2f}%\n'
      .format(len(dictionary.word2idx), em_word, em_word/len(dictionary.word2idx)*100))

'''
对训练集中的词汇数进行统计
'''
# 按照词汇出现数量进行排序，大→小
dict_list = [(key, value) for key, value in dictionary.word2count.items()]
dict_list.sort(key=lambda x:x[1], reverse=True)

# 统计词数总量和具有词向量的词汇量
total_count = 0
em_count = 0
for d in dict_list:
    total_count += d[1]
    if d[0] in vector:
        em_count += d[1]
print('词汇总数为: {}'
      '\n具有词向量的词汇量为: {}'
      '\n占比为: {:.2f}%'
      .format(total_count, em_count, em_count/total_count*100))

词汇总量为: 50201
具有词向量的词汇量为: 47980
占比为: 95.58%

词汇总数为: 1585320
具有词向量的词汇量为: 1567632
占比为: 98.88%


In [7]:
'''
构建训练集、测试集序列化数据
'''
EMBEDDING_DIM = 200
# 词嵌入矩阵
embedding_matrix = np.zeros((len(dictionary.word2idx), EMBEDDING_DIM), dtype=np.float32)

# 依据 word2idx 的映射，为 embedding_matrix 填充词向量
texts = []
labels = []
for review in df['review']:
    text = []
    for i, word in enumerate(review):
        # 如果在字典中没有找到该词汇，设置为 1 即 <UNK>
        embedding_index = dictionary.word2idx.get(word, 1)
        text.append(embedding_index)
        # 如果该词有预训练词向量，初始化
        if word in vector:
            embedding_matrix[embedding_index, :] = vector[word]
    texts.append(text)

# 构建词嵌入 Tensor
embedding_matrix = torch.Tensor(embedding_matrix)


for label in df['label']:
    labels.append(label)
    
'''
对数据集进行 padding/truncation 操作, pre / post
'''
FIXED_LENGTH = 25

def pad_truncate(dataset, length=50, type='pre'):
    fixed_texts = []
    
    for text in dataset:
        fixed_text = []
        
        if len(text) < length:
            fixed_text = text[:] + [0] * (length-len(text))
        else:
            if type == 'pre':  # 截取前面 length 个词
                fixed_text = text[:length]
            elif type == 'post':  # 截取后面 length 个词
                fixed_text = text[-length:]
                
        fixed_texts.append(fixed_text)
        
    return fixed_texts

texts = pad_truncate(texts, length=FIXED_LENGTH, type='pre')

In [8]:
'''
划分训练集、验证集(5000)
'''
np.random.seed(1)
data_size = len(texts)
indices = np.arange(data_size)
np.random.shuffle(indices)

# 训练集
train_texts = np.array(texts)[indices[:-5000]]
train_labels = np.array(labels)[indices[:-5000]]
# 验证集
val_texts = np.array(texts)[indices[-5000:]]
val_labels = np.array(labels)[indices[-5000:]]

train_x = torch.Tensor(train_texts).long()
train_y = torch.Tensor(train_labels).long()
val_x = torch.Tensor(val_texts).long()
val_y = torch.Tensor(val_labels).long()

In [22]:
'''
TextCNN 模型
'''
    
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def create_embedding_layer(embedding_matrix, trainable=True):
        vocab_size, embed_dim = embedding_matrix.size()
        embed_layer = nn.Embedding(vocab_size, embed_dim)
        embed_layer.load_state_dict({'weight': embedding_matrix})
        embed_layer.weight.requires_grad = trainable
        return embed_layer, vocab_size, embed_dim
    
class TextCNN(nn.Module):
    def __init__(self, embedding_matrix):
        super(TextCNN, self).__init__()
        self.embed, self.vocab_size, self.embed_dim = create_embedding_layer(embedding_matrix, trainable=True)

        self.conv1d_1 = nn.Conv1d(200, 100, 2) # 输入的 channels 为 embedding_dim
        self.conv1d_2 = nn.Conv1d(200, 100, 3)
        self.conv1d_3 = nn.Conv1d(200, 100, 4)
        
        nn.init.kaiming_normal_(self.conv1d_1.weight.data)
        nn.init.kaiming_normal_(self.conv1d_2.weight.data)
        nn.init.kaiming_normal_(self.conv1d_3.weight.data)
        
        self.relu = nn.ReLU()
        
        self.maxpool_1 = nn.AdaptiveMaxPool1d(1)
        self.maxpool_2 = nn.AdaptiveMaxPool1d(1)
        self.maxpool_3 = nn.AdaptiveMaxPool1d(1)
        
        self.dropout_1 = nn.Dropout(0.7)
        self.fc_1 = nn.Linear(100*3, 32)
        nn.init.kaiming_normal_(self.fc_1.weight.data)
        
        self.dropout_2 = nn.Dropout(0.7)
        self.fc_2 = nn.Linear(32, 2)
        nn.init.kaiming_normal_(self.fc_2.weight.data)
        
    def forward(self, x):
        out = self.embed(x)
        # (bs, sentence_len, embedding_dim) → (bs, embedding_dim, sentence_len)
        out = out.permute([0, 2, 1])
        
        out_1 = self.conv1d_1(out)
        out_2 = self.conv1d_2(out)
        out_3 = self.conv1d_3(out)
        # (bs, 64, sentence_len - filter_size + 1)
        
        out_1 = self.relu(out_1)
        out_2 = self.relu(out_2)
        out_3 = self.relu(out_3)
        
        # 在第 3 维度上取最大值 (bs, 64, 1)
        out_1 = self.maxpool_1(out_1).squeeze(2)
        out_2 = self.maxpool_2(out_2).squeeze(2)
        out_3 = self.maxpool_3(out_3).squeeze(2)
        
        cat = torch.cat((out_1, out_2, out_3), dim=1)
        
        out = self.dropout_1(cat)
        out = self.fc_1(out)
        out = self.dropout_2(out)
        out = self.fc_2(out)
        return out

In [21]:
'''
手动进行 SGD 训练
'''
# hyperparameters
batch_size = 32
epochs = 5
learning_rate = 0.0003

'''
使用 TensorDataset 和 DataLoader 存放数据集
'''
# 加载为 TensorDataset 对象
train_ds = TensorDataset(train_x, train_y)
val_ds = TensorDataset(val_x, val_y)

# 加载为 DataLoader 对象
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=len(val_ds), shuffle=False)

# initialize Model
model = TextCNN(embedding_matrix).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# fit
for epoch in range(epochs):
    model.train()
    
    with tqdm(enumerate(train_dl)) as t:
        for i, (x, y) in t:
            x = x.to(device)
            y = y.to(device)
            
            out = model(x)
            loss = criterion(out, y)

            # backprop and optimizer
            model.zero_grad()
            loss.backward()
            clip_grad_norm_(model.parameters(), 2.0)
            optimizer.step()

            if (i+1) % 50 == 0:
                print('Epoch: [{}/{}], 正在训练：[{}/{}] 个样本'
                      .format(epoch+1, epochs, (i+1)*batch_size, len(train_x)))

                with torch.no_grad():
                    out_ = model(x)
                    correct = torch.sum(torch.argmax(out_, dim=1) == y).item()
                    print('Val: {:.3f}%'.format(correct/batch_size*100))
t.close()

51it [00:06,  9.35it/s]

Epoch: [1/5], 正在训练：[1600/57773] 个样本
Val: 59.375%


101it [00:12,  9.02it/s]

Epoch: [1/5], 正在训练：[3200/57773] 个样本
Val: 84.375%


151it [00:19,  6.95it/s]

Epoch: [1/5], 正在训练：[4800/57773] 个样本
Val: 65.625%


201it [00:26,  8.89it/s]

Epoch: [1/5], 正在训练：[6400/57773] 个样本
Val: 71.875%


251it [00:32,  8.03it/s]

Epoch: [1/5], 正在训练：[8000/57773] 个样本
Val: 81.250%


301it [00:39,  8.99it/s]

Epoch: [1/5], 正在训练：[9600/57773] 个样本
Val: 78.125%


351it [00:45,  6.42it/s]

Epoch: [1/5], 正在训练：[11200/57773] 个样本
Val: 84.375%


401it [00:52,  7.23it/s]

Epoch: [1/5], 正在训练：[12800/57773] 个样本
Val: 81.250%


451it [00:58,  8.98it/s]

Epoch: [1/5], 正在训练：[14400/57773] 个样本
Val: 84.375%


501it [01:04,  6.42it/s]

Epoch: [1/5], 正在训练：[16000/57773] 个样本
Val: 84.375%


552it [01:11, 10.07it/s]

Epoch: [1/5], 正在训练：[17600/57773] 个样本
Val: 81.250%


601it [01:17,  9.01it/s]

Epoch: [1/5], 正在训练：[19200/57773] 个样本
Val: 71.875%


651it [01:24,  6.47it/s]

Epoch: [1/5], 正在训练：[20800/57773] 个样本
Val: 90.625%


701it [01:31,  6.30it/s]

Epoch: [1/5], 正在训练：[22400/57773] 个样本
Val: 87.500%


751it [01:40,  6.43it/s]

Epoch: [1/5], 正在训练：[24000/57773] 个样本
Val: 68.750%


801it [01:46,  8.48it/s]

Epoch: [1/5], 正在训练：[25600/57773] 个样本
Val: 93.750%


851it [01:54,  6.11it/s]

Epoch: [1/5], 正在训练：[27200/57773] 个样本
Val: 87.500%


901it [02:01,  7.54it/s]

Epoch: [1/5], 正在训练：[28800/57773] 个样本
Val: 87.500%


951it [02:07,  8.47it/s]

Epoch: [1/5], 正在训练：[30400/57773] 个样本
Val: 84.375%


1001it [02:13,  7.63it/s]

Epoch: [1/5], 正在训练：[32000/57773] 个样本
Val: 93.750%


1051it [02:20,  7.35it/s]

Epoch: [1/5], 正在训练：[33600/57773] 个样本
Val: 84.375%


1101it [02:27,  7.04it/s]

Epoch: [1/5], 正在训练：[35200/57773] 个样本
Val: 84.375%


1150it [02:36,  5.31it/s]

Epoch: [1/5], 正在训练：[36800/57773] 个样本
Val: 90.625%


1201it [02:44,  7.09it/s]

Epoch: [1/5], 正在训练：[38400/57773] 个样本
Val: 87.500%


1251it [02:51,  6.65it/s]

Epoch: [1/5], 正在训练：[40000/57773] 个样本
Val: 87.500%


1300it [02:59,  5.49it/s]

Epoch: [1/5], 正在训练：[41600/57773] 个样本
Val: 90.625%


1350it [03:08,  5.09it/s]

Epoch: [1/5], 正在训练：[43200/57773] 个样本
Val: 93.750%


1401it [03:18,  5.31it/s]

Epoch: [1/5], 正在训练：[44800/57773] 个样本
Val: 87.500%


1451it [03:27,  6.63it/s]

Epoch: [1/5], 正在训练：[46400/57773] 个样本
Val: 90.625%


1501it [03:34,  6.68it/s]

Epoch: [1/5], 正在训练：[48000/57773] 个样本
Val: 78.125%


1550it [03:43,  6.14it/s]

Epoch: [1/5], 正在训练：[49600/57773] 个样本
Val: 93.750%


1601it [03:51,  6.46it/s]

Epoch: [1/5], 正在训练：[51200/57773] 个样本
Val: 81.250%


1651it [03:59,  5.33it/s]

Epoch: [1/5], 正在训练：[52800/57773] 个样本
Val: 90.625%


1701it [04:07,  5.72it/s]

Epoch: [1/5], 正在训练：[54400/57773] 个样本
Val: 93.750%


1751it [04:16,  5.30it/s]

Epoch: [1/5], 正在训练：[56000/57773] 个样本
Val: 100.000%


1800it [04:25,  4.97it/s]

Epoch: [1/5], 正在训练：[57600/57773] 个样本
Val: 100.000%


1806it [04:27,  5.35it/s]
51it [00:09,  5.95it/s]

Epoch: [2/5], 正在训练：[1600/57773] 个样本
Val: 87.500%


101it [00:18,  5.78it/s]

Epoch: [2/5], 正在训练：[3200/57773] 个样本
Val: 84.375%


150it [00:27,  4.74it/s]

Epoch: [2/5], 正在训练：[4800/57773] 个样本
Val: 96.875%


200it [00:37,  4.99it/s]

Epoch: [2/5], 正在训练：[6400/57773] 个样本
Val: 93.750%


250it [00:46,  4.68it/s]

Epoch: [2/5], 正在训练：[8000/57773] 个样本
Val: 81.250%


301it [00:55,  6.39it/s]

Epoch: [2/5], 正在训练：[9600/57773] 个样本
Val: 87.500%


350it [01:03,  4.69it/s]

Epoch: [2/5], 正在训练：[11200/57773] 个样本
Val: 93.750%


401it [01:11,  6.45it/s]

Epoch: [2/5], 正在训练：[12800/57773] 个样本
Val: 100.000%


451it [01:21,  6.43it/s]

Epoch: [2/5], 正在训练：[14400/57773] 个样本
Val: 90.625%


501it [01:29,  5.58it/s]

Epoch: [2/5], 正在训练：[16000/57773] 个样本
Val: 90.625%


551it [01:37,  5.38it/s]

Epoch: [2/5], 正在训练：[17600/57773] 个样本
Val: 93.750%


601it [01:46,  6.59it/s]

Epoch: [2/5], 正在训练：[19200/57773] 个样本
Val: 90.625%


651it [01:54,  6.76it/s]

Epoch: [2/5], 正在训练：[20800/57773] 个样本
Val: 96.875%


701it [02:02,  6.53it/s]

Epoch: [2/5], 正在训练：[22400/57773] 个样本
Val: 87.500%


751it [02:10,  5.78it/s]

Epoch: [2/5], 正在训练：[24000/57773] 个样本
Val: 96.875%


801it [02:18,  6.45it/s]

Epoch: [2/5], 正在训练：[25600/57773] 个样本
Val: 93.750%


850it [02:27,  5.08it/s]

Epoch: [2/5], 正在训练：[27200/57773] 个样本
Val: 81.250%


901it [02:35,  6.90it/s]

Epoch: [2/5], 正在训练：[28800/57773] 个样本
Val: 90.625%


951it [02:43,  7.28it/s]

Epoch: [2/5], 正在训练：[30400/57773] 个样本
Val: 93.750%


1001it [02:52,  5.13it/s]

Epoch: [2/5], 正在训练：[32000/57773] 个样本
Val: 93.750%


1051it [03:01,  6.22it/s]

Epoch: [2/5], 正在训练：[33600/57773] 个样本
Val: 90.625%


1101it [03:09,  5.98it/s]

Epoch: [2/5], 正在训练：[35200/57773] 个样本
Val: 84.375%


1151it [03:18,  6.61it/s]

Epoch: [2/5], 正在训练：[36800/57773] 个样本
Val: 87.500%


1200it [03:26,  5.09it/s]

Epoch: [2/5], 正在训练：[38400/57773] 个样本
Val: 84.375%


1251it [03:34,  6.20it/s]

Epoch: [2/5], 正在训练：[40000/57773] 个样本
Val: 93.750%


1300it [03:44,  5.04it/s]

Epoch: [2/5], 正在训练：[41600/57773] 个样本
Val: 78.125%


1351it [03:53,  6.87it/s]

Epoch: [2/5], 正在训练：[43200/57773] 个样本
Val: 93.750%


1400it [04:02,  4.21it/s]

Epoch: [2/5], 正在训练：[44800/57773] 个样本
Val: 90.625%


1450it [04:13,  4.32it/s]

Epoch: [2/5], 正在训练：[46400/57773] 个样本
Val: 93.750%


1500it [04:22,  5.02it/s]

Epoch: [2/5], 正在训练：[48000/57773] 个样本
Val: 90.625%


1550it [04:32,  5.38it/s]

Epoch: [2/5], 正在训练：[49600/57773] 个样本
Val: 87.500%


1601it [04:41,  7.11it/s]

Epoch: [2/5], 正在训练：[51200/57773] 个样本
Val: 84.375%


1651it [04:48,  6.44it/s]

Epoch: [2/5], 正在训练：[52800/57773] 个样本
Val: 87.500%


1700it [04:57,  6.15it/s]

Epoch: [2/5], 正在训练：[54400/57773] 个样本
Val: 93.750%


1751it [05:04,  6.82it/s]

Epoch: [2/5], 正在训练：[56000/57773] 个样本
Val: 93.750%


1801it [05:12,  6.56it/s]

Epoch: [2/5], 正在训练：[57600/57773] 个样本
Val: 90.625%


1806it [05:13,  6.53it/s]
51it [00:08,  6.01it/s]

Epoch: [3/5], 正在训练：[1600/57773] 个样本
Val: 93.750%


101it [00:16,  6.52it/s]

Epoch: [3/5], 正在训练：[3200/57773] 个样本
Val: 96.875%


151it [00:25,  5.88it/s]

Epoch: [3/5], 正在训练：[4800/57773] 个样本
Val: 93.750%


200it [00:34,  5.70it/s]

Epoch: [3/5], 正在训练：[6400/57773] 个样本
Val: 87.500%


250it [00:44,  4.57it/s]

Epoch: [3/5], 正在训练：[8000/57773] 个样本
Val: 90.625%


300it [00:53,  4.61it/s]

Epoch: [3/5], 正在训练：[9600/57773] 个样本
Val: 87.500%


350it [01:01,  5.43it/s]

Epoch: [3/5], 正在训练：[11200/57773] 个样本
Val: 93.750%


401it [01:10,  5.67it/s]

Epoch: [3/5], 正在训练：[12800/57773] 个样本
Val: 87.500%


450it [01:18,  5.85it/s]

Epoch: [3/5], 正在训练：[14400/57773] 个样本
Val: 96.875%


501it [01:27,  6.11it/s]

Epoch: [3/5], 正在训练：[16000/57773] 个样本
Val: 96.875%


550it [01:35,  5.49it/s]

Epoch: [3/5], 正在训练：[17600/57773] 个样本
Val: 96.875%


601it [01:44,  5.20it/s]

Epoch: [3/5], 正在训练：[19200/57773] 个样本
Val: 100.000%


651it [01:53,  6.47it/s]

Epoch: [3/5], 正在训练：[20800/57773] 个样本
Val: 87.500%


700it [02:02,  4.74it/s]

Epoch: [3/5], 正在训练：[22400/57773] 个样本
Val: 84.375%


751it [02:11,  6.51it/s]

Epoch: [3/5], 正在训练：[24000/57773] 个样本
Val: 93.750%


800it [02:20,  4.91it/s]

Epoch: [3/5], 正在训练：[25600/57773] 个样本
Val: 90.625%


850it [02:30,  5.30it/s]

Epoch: [3/5], 正在训练：[27200/57773] 个样本
Val: 90.625%


900it [02:40,  4.82it/s]

Epoch: [3/5], 正在训练：[28800/57773] 个样本
Val: 100.000%


951it [02:50,  6.25it/s]

Epoch: [3/5], 正在训练：[30400/57773] 个样本
Val: 90.625%


1001it [02:58,  6.05it/s]

Epoch: [3/5], 正在训练：[32000/57773] 个样本
Val: 96.875%


1050it [03:07,  4.87it/s]

Epoch: [3/5], 正在训练：[33600/57773] 个样本
Val: 90.625%


1101it [03:17,  5.91it/s]

Epoch: [3/5], 正在训练：[35200/57773] 个样本
Val: 90.625%


1151it [03:27,  4.90it/s]

Epoch: [3/5], 正在训练：[36800/57773] 个样本
Val: 90.625%


1201it [03:36,  5.40it/s]

Epoch: [3/5], 正在训练：[38400/57773] 个样本
Val: 90.625%


1251it [03:44,  6.30it/s]

Epoch: [3/5], 正在训练：[40000/57773] 个样本
Val: 93.750%


1301it [03:52,  4.97it/s]

Epoch: [3/5], 正在训练：[41600/57773] 个样本
Val: 90.625%


1351it [04:00,  6.05it/s]

Epoch: [3/5], 正在训练：[43200/57773] 个样本
Val: 93.750%


1401it [04:09,  6.46it/s]

Epoch: [3/5], 正在训练：[44800/57773] 个样本
Val: 90.625%


1450it [04:17,  6.50it/s]

Epoch: [3/5], 正在训练：[46400/57773] 个样本
Val: 96.875%


1500it [04:26,  6.10it/s]

Epoch: [3/5], 正在训练：[48000/57773] 个样本
Val: 87.500%


1551it [04:35,  6.35it/s]

Epoch: [3/5], 正在训练：[49600/57773] 个样本
Val: 93.750%


1601it [04:43,  5.33it/s]

Epoch: [3/5], 正在训练：[51200/57773] 个样本
Val: 100.000%


1651it [04:53,  5.27it/s]

Epoch: [3/5], 正在训练：[52800/57773] 个样本
Val: 90.625%


1701it [05:01,  6.84it/s]

Epoch: [3/5], 正在训练：[54400/57773] 个样本
Val: 84.375%


1751it [05:08,  6.78it/s]

Epoch: [3/5], 正在训练：[56000/57773] 个样本
Val: 93.750%


1801it [05:18,  6.22it/s]

Epoch: [3/5], 正在训练：[57600/57773] 个样本
Val: 90.625%


1806it [05:18,  6.61it/s]
51it [00:08,  6.59it/s]

Epoch: [4/5], 正在训练：[1600/57773] 个样本
Val: 93.750%


100it [00:18,  5.04it/s]

Epoch: [4/5], 正在训练：[3200/57773] 个样本
Val: 96.875%


151it [00:29,  4.51it/s]

Epoch: [4/5], 正在训练：[4800/57773] 个样本
Val: 93.750%


201it [00:40,  5.32it/s]

Epoch: [4/5], 正在训练：[6400/57773] 个样本
Val: 93.750%


250it [00:51,  5.02it/s]

Epoch: [4/5], 正在训练：[8000/57773] 个样本
Val: 96.875%


300it [01:00,  4.11it/s]

Epoch: [4/5], 正在训练：[9600/57773] 个样本
Val: 93.750%


350it [01:09,  5.52it/s]

Epoch: [4/5], 正在训练：[11200/57773] 个样本
Val: 90.625%


401it [01:20,  5.56it/s]

Epoch: [4/5], 正在训练：[12800/57773] 个样本
Val: 93.750%


451it [01:28,  5.66it/s]

Epoch: [4/5], 正在训练：[14400/57773] 个样本
Val: 93.750%


500it [01:38,  4.28it/s]

Epoch: [4/5], 正在训练：[16000/57773] 个样本
Val: 93.750%


551it [01:48,  5.68it/s]

Epoch: [4/5], 正在训练：[17600/57773] 个样本
Val: 93.750%


601it [01:57,  6.23it/s]

Epoch: [4/5], 正在训练：[19200/57773] 个样本
Val: 90.625%


651it [02:04,  6.98it/s]

Epoch: [4/5], 正在训练：[20800/57773] 个样本
Val: 93.750%


701it [02:12,  6.61it/s]

Epoch: [4/5], 正在训练：[22400/57773] 个样本
Val: 84.375%


750it [02:21,  5.38it/s]

Epoch: [4/5], 正在训练：[24000/57773] 个样本
Val: 93.750%


800it [02:30,  5.18it/s]

Epoch: [4/5], 正在训练：[25600/57773] 个样本
Val: 96.875%


850it [02:40,  5.33it/s]

Epoch: [4/5], 正在训练：[27200/57773] 个样本
Val: 84.375%


901it [02:48,  6.80it/s]

Epoch: [4/5], 正在训练：[28800/57773] 个样本
Val: 96.875%


950it [02:56,  5.27it/s]

Epoch: [4/5], 正在训练：[30400/57773] 个样本
Val: 96.875%


1000it [03:06,  6.77it/s]

Epoch: [4/5], 正在训练：[32000/57773] 个样本
Val: 84.375%


1051it [03:14,  5.14it/s]

Epoch: [4/5], 正在训练：[33600/57773] 个样本
Val: 90.625%


1101it [03:23,  6.34it/s]

Epoch: [4/5], 正在训练：[35200/57773] 个样本
Val: 90.625%


1151it [03:32,  6.51it/s]

Epoch: [4/5], 正在训练：[36800/57773] 个样本
Val: 93.750%


1201it [03:42,  5.57it/s]

Epoch: [4/5], 正在训练：[38400/57773] 个样本
Val: 96.875%


1251it [03:50,  5.79it/s]

Epoch: [4/5], 正在训练：[40000/57773] 个样本
Val: 93.750%


1301it [03:58,  6.57it/s]

Epoch: [4/5], 正在训练：[41600/57773] 个样本
Val: 96.875%


1351it [04:07,  5.88it/s]

Epoch: [4/5], 正在训练：[43200/57773] 个样本
Val: 93.750%


1401it [04:16,  5.91it/s]

Epoch: [4/5], 正在训练：[44800/57773] 个样本
Val: 100.000%


1451it [04:26,  5.80it/s]

Epoch: [4/5], 正在训练：[46400/57773] 个样本
Val: 87.500%


1501it [04:34,  7.23it/s]

Epoch: [4/5], 正在训练：[48000/57773] 个样本
Val: 84.375%


1551it [04:42,  5.26it/s]

Epoch: [4/5], 正在训练：[49600/57773] 个样本
Val: 87.500%


1600it [04:51,  4.95it/s]

Epoch: [4/5], 正在训练：[51200/57773] 个样本
Val: 93.750%


1651it [05:01,  5.05it/s]

Epoch: [4/5], 正在训练：[52800/57773] 个样本
Val: 93.750%


1701it [05:09,  6.71it/s]

Epoch: [4/5], 正在训练：[54400/57773] 个样本
Val: 90.625%


1750it [05:16,  5.00it/s]

Epoch: [4/5], 正在训练：[56000/57773] 个样本
Val: 90.625%


1801it [05:27,  6.18it/s]

Epoch: [4/5], 正在训练：[57600/57773] 个样本
Val: 96.875%


1806it [05:28,  5.17it/s]
51it [00:09,  7.33it/s]

Epoch: [5/5], 正在训练：[1600/57773] 个样本
Val: 96.875%


100it [00:18,  4.50it/s]

Epoch: [5/5], 正在训练：[3200/57773] 个样本
Val: 100.000%


151it [00:26,  5.80it/s]

Epoch: [5/5], 正在训练：[4800/57773] 个样本
Val: 96.875%


201it [00:35,  5.68it/s]

Epoch: [5/5], 正在训练：[6400/57773] 个样本
Val: 100.000%


251it [00:43,  6.52it/s]

Epoch: [5/5], 正在训练：[8000/57773] 个样本
Val: 93.750%


300it [00:52,  4.51it/s]

Epoch: [5/5], 正在训练：[9600/57773] 个样本
Val: 96.875%


351it [01:01,  5.45it/s]

Epoch: [5/5], 正在训练：[11200/57773] 个样本
Val: 96.875%


400it [01:10,  4.75it/s]

Epoch: [5/5], 正在训练：[12800/57773] 个样本
Val: 96.875%


450it [01:18,  5.02it/s]

Epoch: [5/5], 正在训练：[14400/57773] 个样本
Val: 96.875%


500it [01:28,  5.01it/s]

Epoch: [5/5], 正在训练：[16000/57773] 个样本
Val: 96.875%


550it [01:38,  4.62it/s]

Epoch: [5/5], 正在训练：[17600/57773] 个样本
Val: 100.000%


601it [01:47,  4.90it/s]

Epoch: [5/5], 正在训练：[19200/57773] 个样本
Val: 90.625%


651it [01:58,  5.49it/s]

Epoch: [5/5], 正在训练：[20800/57773] 个样本
Val: 93.750%


700it [02:07,  5.00it/s]

Epoch: [5/5], 正在训练：[22400/57773] 个样本
Val: 96.875%


751it [02:16,  6.31it/s]

Epoch: [5/5], 正在训练：[24000/57773] 个样本
Val: 93.750%


800it [02:25,  5.32it/s]

Epoch: [5/5], 正在训练：[25600/57773] 个样本
Val: 96.875%


850it [02:38,  4.13it/s]

Epoch: [5/5], 正在训练：[27200/57773] 个样本
Val: 96.875%


900it [02:48,  5.09it/s]

Epoch: [5/5], 正在训练：[28800/57773] 个样本
Val: 87.500%


950it [02:59,  4.23it/s]

Epoch: [5/5], 正在训练：[30400/57773] 个样本
Val: 93.750%


1000it [03:10,  4.36it/s]

Epoch: [5/5], 正在训练：[32000/57773] 个样本
Val: 100.000%


1050it [03:22,  4.59it/s]

Epoch: [5/5], 正在训练：[33600/57773] 个样本
Val: 96.875%


1100it [03:34,  4.13it/s]

Epoch: [5/5], 正在训练：[35200/57773] 个样本
Val: 93.750%


1150it [03:46,  4.21it/s]

Epoch: [5/5], 正在训练：[36800/57773] 个样本
Val: 96.875%


1201it [03:57,  4.93it/s]

Epoch: [5/5], 正在训练：[38400/57773] 个样本
Val: 90.625%


1250it [04:08,  4.46it/s]

Epoch: [5/5], 正在训练：[40000/57773] 个样本
Val: 96.875%


1301it [04:19,  6.56it/s]

Epoch: [5/5], 正在训练：[41600/57773] 个样本
Val: 100.000%


1351it [04:28,  5.16it/s]

Epoch: [5/5], 正在训练：[43200/57773] 个样本
Val: 96.875%


1401it [04:36,  5.20it/s]

Epoch: [5/5], 正在训练：[44800/57773] 个样本
Val: 96.875%


1451it [04:45,  6.52it/s]

Epoch: [5/5], 正在训练：[46400/57773] 个样本
Val: 100.000%


1500it [04:55,  4.99it/s]

Epoch: [5/5], 正在训练：[48000/57773] 个样本
Val: 96.875%


1550it [05:05,  4.75it/s]

Epoch: [5/5], 正在训练：[49600/57773] 个样本
Val: 87.500%


1600it [05:13,  5.69it/s]

Epoch: [5/5], 正在训练：[51200/57773] 个样本
Val: 93.750%


1651it [05:21,  6.21it/s]

Epoch: [5/5], 正在训练：[52800/57773] 个样本
Val: 90.625%


1701it [05:29,  6.44it/s]

Epoch: [5/5], 正在训练：[54400/57773] 个样本
Val: 100.000%


1750it [05:38,  4.77it/s]

Epoch: [5/5], 正在训练：[56000/57773] 个样本
Val: 93.750%


1801it [05:47,  5.71it/s]

Epoch: [5/5], 正在训练：[57600/57773] 个样本
Val: 93.750%


1806it [05:48,  6.49it/s]


In [23]:
'''
训练集 Accuracy
'''
with torch.no_grad():
    correct = 0
    
    for _, (x_, y_) in enumerate(train_dl):
        out_ = model(x_.to(device))
        correct += torch.sum(torch.argmax(out_, dim=1) == y_.to(device)).item()
        
    print('训练集 Acc: {:.3f}%'.format(correct / len(train_x) * 100))

'''
验证集 Accuracy
'''
with torch.no_grad():
    correct = 0
    
    for _, (x_, y_) in enumerate(val_dl):
        out_ = model(x_.to(device))
        correct += torch.sum(torch.argmax(out_, dim=1) == y_.to(device)).item()
        
    print('验证集 Acc: {:.3f}%'.format(correct / len(val_x) * 100))

训练集 Acc: 96.057%
验证集 Acc: 91.080%
