# 通过RNN实现图片分类

一般来说卷积神经网络是处理图片的能手，但这并不意味着循环神经网络不具备处理图片的能力，我们依然可以使用循环神经网络完成图片的分类任务。

In [1]:
import torch
from torch import nn
import numpy as np
from torchvision.datasets import mnist
from torch.utils.data import DataLoader
from torchvision import transforms
from datetime import datetime

In [2]:
def get_acc(output, label):
    """Return the fraction of rows in ``output`` whose highest score matches ``label``.

    ``output`` holds per-class scores with classes along dim 1 and ``label``
    holds target class indices. The row-wise argmax is taken as the predicted
    (most likely) class. The result is a 0-dim float tensor.
    """
    predicted = output.max(1)[1]  # index of the largest score in each row
    hits = predicted.eq(label).sum().float()
    return hits / output.size(0)

# Preprocessing pipeline: ToTensor turns the image into a float tensor scaled
# to [0, 1]; Normalize then maps the single channel via (x - 0.5) / 0.5.
data_tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

In [3]:
# MNIST train/test splits stored under ./data (downloaded if missing),
# each sample preprocessed with data_tf.
train_set = mnist.MNIST('./data', train=True, transform=data_tf, download=True)
test_set = mnist.MNIST('./data', train=False, transform=data_tf, download=True)
# Shuffle only the training data. Evaluation order is kept fixed so the reported
# validation metrics are reproducible for a given model (with shuffling, the
# composition of the smaller final batch changes from run to run).
train_data = DataLoader(train_set, batch_size=128, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)

MNIST手写数字数据集的图片都是单通道的，大小为28*28。循环神经网络擅长处理序列数据，因此我们可以把图片看成序列数据：`将每张图片看作是长为28的序列`（图片的每一行是一个时间步），`序列中的每个元素的特征维度是28`（即一行中的28个像素）。这样就把一张图片变成了一个序列。

In [4]:
class RNN_MNIST(nn.Module):
    """Sequence classifier: a stacked RNN whose last-step output feeds a linear layer.

    Args:
        input_size: feature size of each time step (28 pixels per image row here).
        hidden_size: number of hidden units in each RNN layer.
        num_layers: number of stacked RNN layers.
        num_class: number of output classes (logits produced per sample).
    """

    def __init__(self, input_size, hidden_size, num_layers, num_class):
        super().__init__()
        # batch_first=True: inputs and outputs are laid out as (batch, seq_len, features).
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        self.classifier = nn.Linear(hidden_size, num_class)

    def forward(self, x):
        # seq_out: (batch, seq_len, hidden_size), the RNN output at every time step.
        seq_out, _ = self.rnn(x)
        last_step = seq_out[:, -1, :]  # keep only the final time step
        logits = self.classifier(last_step)
        return logits

In [5]:
# 28-dim inputs per step, 50 hidden units, 2 stacked RNN layers, 10 classes.
net = RNN_MNIST(input_size=28, hidden_size=50, num_layers=2, num_class=10)
criterion = nn.CrossEntropyLoss()  # loss function
optimizer = torch.optim.SGD(net.parameters(), lr=0.02)  # plain SGD, learning rate 0.02
prev_time = datetime.now()  # timestamp taken at the end of setup

In [6]:
# Show the model's submodule structure (nn.Module's printable representation).
print(str(net))

RNN_MNIST(
  (rnn): RNN(28, 50, num_layers=2, batch_first=True)
  (classifier): Linear(in_features=50, out_features=10, bias=True)
)


In [7]:
# Train for 10 epochs. After each epoch, evaluate on the test split and print
# sample-weighted loss/accuracy together with the epoch's wall-clock time.
for epoch in range(10):
    epoch_start = datetime.now()

    # ---- training ----
    train_loss = 0.0
    train_acc = 0.0
    net.train()
    for im, label in train_data:  # one mini-batch of up to batch_size=128 samples
        # Images arrive as (batch, 1, 28, 28). squeeze(1) drops the single channel
        # axis, giving (batch, 28, 28): each image becomes a 28-step sequence of
        # 28-dim rows, matching the batch_first RNN input layout.
        im = im.squeeze(1)
        output = net(im)
        loss = criterion(output, label)
        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Weight each batch by its size: the final batch is smaller, so a plain
        # mean over per-batch values would over-weight it.
        n_batch = label.size(0)
        train_loss += loss.item() * n_batch
        train_acc += get_acc(output, label).item() * n_batch

    # ---- validation ----
    valid_loss = 0.0
    valid_acc = 0.0
    net.eval()
    with torch.no_grad():  # evaluation needs no gradients; skip building autograd graphs
        for im, label in test_data:
            im = im.squeeze(1)
            output = net(im)
            loss = criterion(output, label)
            n_batch = label.size(0)
            valid_loss += loss.item() * n_batch
            valid_acc += get_acc(output, label).item() * n_batch

    # Elapsed time covers this epoch's training and validation.
    cur_time = datetime.now()
    h, remainder = divmod(int((cur_time - epoch_start).total_seconds()), 3600)
    m, s = divmod(remainder, 60)
    time_str = "Time %02d:%02d:%02d" % (h, m, s)

    n_train = len(train_data.dataset)
    n_valid = len(test_data.dataset)
    epoch_str = ("Epoch %d. Train Loss %f,Train Acc:%f,Valid Loss: %f,Valid Acc: %f ,"
                 % (epoch, train_loss / n_train,
                    train_acc / n_train,
                    valid_loss / n_valid,
                    valid_acc / n_valid))
    prev_time = cur_time
    print(epoch_str + time_str)

Epoch 0. Train Loss 2.045774,Train Acc:0.264659,Valid Loss: 1.595222,Valid Acc: 0.422567 ,Time 00:00:11
Epoch 1. Train Loss 1.269390,Train Acc:0.557508,Valid Loss: 1.091728,Valid Acc: 0.624209 ,Time 00:00:13
Epoch 2. Train Loss 1.007069,Train Acc:0.647699,Valid Loss: 0.851373,Valid Acc: 0.708861 ,Time 00:00:13
Epoch 3. Train Loss 0.763946,Train Acc:0.737212,Valid Loss: 0.677069,Valid Acc: 0.763548 ,Time 00:00:13
Epoch 4. Train Loss 0.585946,Train Acc:0.806997,Valid Loss: 0.480679,Valid Acc: 0.842662 ,Time 00:00:13
Epoch 5. Train Loss 0.468323,Train Acc:0.852073,Valid Loss: 0.405420,Valid Acc: 0.869858 ,Time 00:00:12
Epoch 6. Train Loss 0.374250,Train Acc:0.890009,Valid Loss: 0.327288,Valid Acc: 0.903778 ,Time 00:00:13
Epoch 7. Train Loss 0.314906,Train Acc:0.911514,Valid Loss: 0.286701,Valid Acc: 0.920293 ,Time 00:00:13
Epoch 8. Train Loss 0.299051,Train Acc:0.915162,Valid Loss: 0.245879,Valid Acc: 0.933742 ,Time 00:00:13
Epoch 9. Train Loss 0.257961,Train Acc:0.927566,Valid Loss: 0.24