# 通过LSTM实现图片分类

一般来说卷积神经网络是处理图片的能手，但这并不意味着循环神经网络不具备处理图片的能力，我们依然可以使用循环神经网络完成图片的分类任务。

In [1]:
import torch
from torch import nn
import numpy as np
from torchvision.datasets import mnist
from torch.utils.data import DataLoader
from torchvision import transforms
from datetime import datetime

In [2]:
def get_acc(output, label):
    """Return the fraction of rows in ``output`` whose arg-max equals ``label``.

    ``output`` is indexed as (sample, class score) and ``label`` holds one
    class index per sample. The result is a 0-d float tensor.
    """
    batch_size = output.shape[0]
    # The column holding the largest score in each row is the predicted class.
    pred_label = torch.max(output, 1)[1]
    hits = pred_label.eq(label).sum().float()
    return hits / batch_size
# Preprocessing applied to every image: convert it to a tensor, then
# normalize the single channel with mean 0.5 and std 0.5.
data_tf = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

In [3]:
# Download (if not already present under ./data) and load the MNIST
# train/test splits, applying data_tf to every image.
train_set = mnist.MNIST('./data', train=True, transform=data_tf, download=True)
test_set = mnist.MNIST('./data', train=False, transform=data_tf, download=True)
# Shuffle only the training batches. Evaluation gains nothing from shuffling,
# and a fixed order keeps per-batch-averaged validation metrics reproducible.
train_data = DataLoader(train_set, batch_size=128, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)

MNIST 手写数字数据集中的图片都是单通道的，大小为 28×28。循环神经网络处理的是序列数据，因此我们可以把图片按行看成序列：**将每张图片看作长度为 28 的序列**（每一行是一个时间步），**序列中每个元素的特征维度是 28**（即该行的 28 个像素），这样就将图片变成了一个序列。

In [4]:
class LSTM_MNIST(nn.Module):
    """LSTM classifier that reads its input as a sequence of feature vectors.

    The input goes through a stacked ``nn.LSTM`` (``batch_first=True``) and
    the LSTM output at the final time step is mapped to ``n_class`` scores
    by a linear layer.
    """

    def __init__(self, in_dim, hidden_dim, n_layer, n_class):
        """Build the network.

        Args:
            in_dim: Feature size of each sequence element.
            hidden_dim: Size of the LSTM hidden state.
            n_layer: Number of stacked LSTM layers.
            n_class: Number of output classes.
        """
        super().__init__()
        self.n_layer = n_layer
        self.hidden_dim = hidden_dim
        self.lstm = nn.LSTM(
            input_size=in_dim,
            hidden_size=hidden_dim,
            num_layers=n_layer,
            batch_first=True,
        )
        self.classifier = nn.Linear(in_features=hidden_dim, out_features=n_class)

    def forward(self, x):
        """Return class scores for ``x`` laid out as (batch, seq_len, in_dim)."""
        # Keep only the per-step outputs; the final (h, c) state is not needed.
        seq_out = self.lstm(x)[0]
        # The last time step summarizes the whole sequence.
        last_step = seq_out[:, -1, :]
        return self.classifier(last_step)

In [5]:
# Model: 28 input features per time step, 50 hidden units, 2 LSTM layers, 10 classes.
net = LSTM_MNIST(in_dim=28, hidden_dim=50, n_layer=2, n_class=10)
# Loss function: cross-entropy on the class scores.
criterion = nn.CrossEntropyLoss()
# Plain SGD with a learning rate of 0.1.
optimizer = torch.optim.SGD(net.parameters(), lr=1e-1)
# Reference timestamp used later when reporting elapsed time.
prev_time = datetime.now()

In [6]:
# Print the module's repr to inspect its layers and their sizes.
print(net)

LSTM_MNIST(
  (lstm): LSTM(28, 50, num_layers=2, batch_first=True)
  (classifier): Linear(in_features=50, out_features=10, bias=True)
)


In [7]:
# Restart the timer right before training so the first epoch's reported time
# does not include whatever ran between setting up the optimizer and this loop.
prev_time = datetime.now()
for epoch in range(10):
    # ---- training pass ----
    train_loss = 0.0
    train_acc = 0.0
    net.train()
    for im, label in train_data:
        # Remove the size-1 channel axis so each image is fed to the LSTM
        # as a sequence of rows.
        im = im.squeeze(1)
        output = net(im)
        loss = criterion(output, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() yields plain Python floats, so the running totals hold no
        # reference to the autograd graph.
        train_loss += loss.item()
        train_acc += get_acc(output, label).item()
    cur_time = datetime.now()
    elapsed = int((cur_time - prev_time).total_seconds())
    h, remainder = divmod(elapsed, 3600)
    m, s = divmod(remainder, 60)
    time_str = "Time %02d:%02d:%02d" % (h, m, s)
    # ---- validation pass ----
    valid_loss = 0.0
    valid_acc = 0.0
    net.eval()
    # No gradients are needed for evaluation; disabling them avoids building
    # the autograd graph and saves memory and time.
    with torch.no_grad():
        for im, label in test_data:
            im = im.squeeze(1)
            output = net(im)
            loss = criterion(output, label)
            valid_loss += loss.item()
            valid_acc += get_acc(output, label).item()
    # Metrics are averaged over the number of batches in each loader.
    epoch_str = ("Epoch %d. Train Loss %f,Train Acc:%f,Valid Loss: %f,Valid Acc: %f ,"
                 % (epoch, train_loss / len(train_data),
                    train_acc / len(train_data),
                    valid_loss / len(test_data),
                    valid_acc / len(test_data)))
    prev_time = cur_time
    print(epoch_str + time_str)

Epoch 0. Train Loss 2.216499,Train Acc:0.178560,Valid Loss: 2.083235,Valid Acc: 0.254648 ,Time 00:00:34
Epoch 1. Train Loss 0.997705,Train Acc:0.663385,Valid Loss: 0.449262,Valid Acc: 0.865704 ,Time 00:00:49
Epoch 2. Train Loss 0.345102,Train Acc:0.894817,Valid Loss: 0.279610,Valid Acc: 0.909810 ,Time 00:00:47
Epoch 3. Train Loss 0.200307,Train Acc:0.940698,Valid Loss: 0.199526,Valid Acc: 0.936907 ,Time 00:00:48
Epoch 4. Train Loss 0.143321,Train Acc:0.957673,Valid Loss: 0.125389,Valid Acc: 0.965289 ,Time 00:00:53
Epoch 5. Train Loss 0.116572,Train Acc:0.965585,Valid Loss: 0.096243,Valid Acc: 0.974288 ,Time 00:00:54
Epoch 6. Train Loss 0.098319,Train Acc:0.970144,Valid Loss: 0.114585,Valid Acc: 0.966278 ,Time 00:00:54
Epoch 7. Train Loss 0.083342,Train Acc:0.975085,Valid Loss: 0.087382,Valid Acc: 0.974090 ,Time 00:00:54
Epoch 8. Train Loss 0.073541,Train Acc:0.977995,Valid Loss: 0.086054,Valid Acc: 0.974782 ,Time 00:00:54
Epoch 9. Train Loss 0.066122,Train Acc:0.980072,Valid Loss: 0.07