# 4.2 RNN分类

对于一张 28×28 的图像，RNN 按行读取：每个时间步输入一行的 28 个像素（先读第一行，再读第二行，以此类推），共 28 个时间步；最后一个时间步的输出用于分类。

In [2]:
import torch
from torch import nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Fix the global RNG seed so that weight initialisation and DataLoader
# shuffling are repeatable across "Restart Kernel -> Run All".
SEED = 1
torch.manual_seed(SEED)

# Hyperparameters
EPOCH = 1        # number of passes over the training set
BATCH_SIZE = 64  # samples per mini-batch
TIME_STEP = 28   # sequence length: one step per image row (image height)
INPUT_SIZE = 28  # features per step: pixels in one row (image width)
LR = 0.01        # learning rate handed to the optimizer

### 1. 读取数据

In [18]:
# Training data.
# download=True only fetches the files when they are missing under `root`,
# so the cell also works on a machine without a pre-populated ./mnist folder.
train_data = dsets.MNIST(
    root='./mnist',
    train=True,
    transform=transforms.ToTensor(),  # scales pixel values into [0, 1]
    download=True,
)
train_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)

# Test data: no transform is attached, so the raw tensors are sliced and
# rescaled by hand. Only the first N_TEST samples are kept.
N_TEST = 2000
test_data = dsets.MNIST(root='./mnist', train=False, download=True)
# shape (N_TEST, 28, 28), float values in [0, 1]; no channel axis is added
test_x = test_data.data.type(torch.FloatTensor)[:N_TEST] / 255.
test_y = test_data.targets.numpy().squeeze()[:N_TEST]

### 2. 建立网络

In [19]:
class RNN(nn.Module):
    """LSTM classifier that reads a sequence and scores it from the last step.

    Args:
        input_size: Number of features fed in at each time step. ``None``
            (the default) falls back to the module-level ``INPUT_SIZE``.
        hidden_size: Width of the LSTM hidden state. It is also the input
            width of the output layer, so the two can never disagree.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output scores per sample.
    """

    def __init__(self, input_size=None, hidden_size=64, num_layers=1, num_classes=10):
        super().__init__()

        if input_size is None:
            input_size = INPUT_SIZE

        self.rnn = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,  # tensors are laid out as (batch, time_step, input_size)
        )

        self.out = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return class scores of shape (batch, num_classes).

        Args:
            x: Input of shape (batch, time_step, input_size).
        """
        # The initial hidden/cell state is left at its default (zeros); the
        # final states returned by the LSTM are not needed here.
        r_out, _ = self.rnn(x)
        # Only the output at the last time step is used: by then the LSTM
        # has seen every step of the sequence.
        return self.out(r_out[:, -1, :])
    
# Build the classifier with its default arguments and print its layer summary.
rnn = RNN()
print(rnn)

RNN(
  (rnn): LSTM(28, 64, batch_first=True)
  (out): Linear(in_features=64, out_features=10, bias=True)
)


In [20]:
# Cross-entropy loss; it is applied to the network's raw output scores.
loss_func = nn.CrossEntropyLoss()
# Adam updates every parameter of the network, using learning rate LR.
optimizer = torch.optim.Adam(
    params=rnn.parameters(),
    lr=LR,
)

### 3. 训练

In [23]:
# Train for EPOCH passes; every 50 steps report the current batch loss and
# the accuracy on the held-out test tensors.
for epoch in range(EPOCH):
    for step, (x, y) in enumerate(train_loader):
        # Drop the channel axis: (batch, 1, 28, 28) -> (batch, time_step, input_size)
        x = x.view(-1, TIME_STEP, INPUT_SIZE)
        output = rnn(x)
        loss = loss_func(output, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step % 50 == 0:
            # Evaluation needs no gradients; without no_grad() autograd
            # would record a graph for the whole test forward pass.
            with torch.no_grad():
                test_output = rnn(test_x)  # (samples, time_step, input_size) in
            pred_y = test_output.argmax(dim=1).numpy()
            accuracy = (pred_y == test_y).mean()
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.item(), '| test accuracy: %.2f' % accuracy)


Epoch:  0 | train loss: 0.7153 | test accuracy: 0.77
Epoch:  0 | train loss: 0.3608 | test accuracy: 0.84
Epoch:  0 | train loss: 0.4278 | test accuracy: 0.84
Epoch:  0 | train loss: 0.2378 | test accuracy: 0.88
Epoch:  0 | train loss: 0.2056 | test accuracy: 0.87
Epoch:  0 | train loss: 0.1539 | test accuracy: 0.91
Epoch:  0 | train loss: 0.1661 | test accuracy: 0.92
Epoch:  0 | train loss: 0.2968 | test accuracy: 0.88
Epoch:  0 | train loss: 0.2276 | test accuracy: 0.92
Epoch:  0 | train loss: 0.1013 | test accuracy: 0.93
Epoch:  0 | train loss: 0.2511 | test accuracy: 0.93
Epoch:  0 | train loss: 0.1204 | test accuracy: 0.94
Epoch:  0 | train loss: 0.1662 | test accuracy: 0.93
Epoch:  0 | train loss: 0.2616 | test accuracy: 0.94
Epoch:  0 | train loss: 0.2011 | test accuracy: 0.94
Epoch:  0 | train loss: 0.0831 | test accuracy: 0.94
Epoch:  0 | train loss: 0.0459 | test accuracy: 0.95
Epoch:  0 | train loss: 0.1259 | test accuracy: 0.95
Epoch:  0 | train loss: 0.3282 | test accuracy

### 4. 测试

In [24]:
# Print predictions for the first 10 test images next to their true labels.
# Inference only, so gradient tracking is switched off.
with torch.no_grad():
    test_output = rnn(test_x[:10].view(-1, TIME_STEP, INPUT_SIZE))
pred_y = test_output.argmax(dim=1).numpy()
print(pred_y, 'prediction number')
print(test_y[:10], 'real number')

[7 2 1 0 4 1 4 9 8 9] prediction number
[7 2 1 0 4 1 4 9 5 9] real number
