In [23]:
import os

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt

In [24]:
# ---- Hyper-parameters ------------------------------------------------------
# Data loading
DOWNLOAD_MNIST = False   # passed as `download=` when the MNIST training set is built
BATCH_SIZE = 64          # mini-batch size handed to the training DataLoader

# Input geometry: both dimensions of the sequence fed to the LSTM are 28
TIME_STEP = 28           # number of steps per sequence (presumably one per image row)
INPUT_SIZE = 28          # features per step; used as the LSTM `input_size`

# Optimisation
LR = 0.01                # learning rate given to the optimizer
EPOCH = 1                # number of passes over the training loader

In [25]:
# MNIST training split, with every image converted to a tensor by ToTensor().
# Download when explicitly requested, or when the local folder is missing or
# empty, so that a fresh "Restart & Run All" does not stop with
# "Dataset not found". An existing, populated folder keeps the original
# behaviour (DOWNLOAD_MNIST alone decides).
mnist_missing = not os.path.isdir('./mnist') or not os.listdir('./mnist')
train_data = torchvision.datasets.MNIST(
    root='./mnist',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNIST or mnist_missing,
)

In [30]:
# Sanity check: show which dataset class was instantiated.
type(train_data)

torchvision.datasets.mnist.MNIST

In [28]:
# Shape of the raw training image tensor (number of samples, height, width).
train_data.data.shape

torch.Size([60000, 28, 28])

In [29]:
# Shape of the training label tensor: one label per image.
train_data.targets.shape

torch.Size([60000])

In [5]:
# Mini-batch iterator over the training set: reshuffled on every pass and
# loaded by two worker processes.
train_loader = Data.DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)

In [7]:
# Held-out MNIST split (train=False), read from the same ./mnist root.
# No `download` argument is passed here, so the library default applies.
test_data = torchvision.datasets.MNIST(
    './mnist',
    train=False,
    transform=torchvision.transforms.ToTensor(),
)

In [8]:
# Evaluation subset: the first 2000 test samples.
# Slice *before* converting so that only the samples actually kept are cast to
# float, and skip the deprecated `Variable` wrapper (plain tensors behave the
# same). The raw images bypass the dataset transform, so they are divided by
# 255 here (raw pixels are presumably 0-255 -- confirm against the dataset).
test_x = test_data.data[:2000].type(torch.FloatTensor) / 255.
# Labels as a NumPy array for the NumPy-side accuracy computation.
# NOTE(review): assumes `targets` is 1-D, as the training targets shown above
# are; the former `.squeeze()` was a no-op under that assumption.
test_y = test_data.targets[:2000].numpy()

In [12]:
class RNN(nn.Module):
    """Single-layer LSTM followed by a linear classification head.

    Each sample is read as a sequence; the LSTM output at the last time step
    is projected to one logit per class.

    Args:
        input_size: Number of features per time step. When omitted, the
            notebook-level ``INPUT_SIZE`` constant is used (original behaviour).
        hidden_size: Width of the LSTM hidden state. Defaults to 64.
        num_classes: Number of output logits. Defaults to 10.
    """

    def __init__(self, input_size=None, hidden_size=64, num_classes=10):
        super().__init__()
        if input_size is None:
            # Resolved at construction time so `RNN()` keeps working unchanged.
            input_size = INPUT_SIZE
        self.rnn = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True,  # the first tensor dimension is the batch
        )
        # The head width must match the LSTM hidden size, so both are driven
        # by the same parameter instead of two separate literals.
        self.out = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Compute class logits.

        Args:
            x: Float tensor of shape (batch, time_step, input_size).

        Returns:
            Tensor of shape (batch, num_classes) with unnormalised scores.
        """
        # r_out: (batch, time_step, hidden_size). No initial state is passed,
        # so the LSTM starts from zeros; the final (h_n, c_n) pair is unused.
        r_out, _ = self.rnn(x)
        # Classify from the output of the last time step only.
        return self.out(r_out[:, -1, :])

In [13]:
# Instantiate the classifier with its default constructor arguments.
rnn = RNN()

In [14]:
# Display the module summary to confirm the layer layout.
rnn

RNN(
  (rnn): LSTM(28, 64, batch_first=True)
  (out): Linear(in_features=64, out_features=10, bias=True)
)

In [15]:
# Training objective: cross-entropy computed from the model's raw logits.
loss_func = nn.CrossEntropyLoss()
# Adam updates every parameter of `rnn` with learning rate LR.
optimizer = torch.optim.Adam(params=rnn.parameters(), lr=LR)

In [19]:
# Train for EPOCH passes over the training loader and report accuracy on the
# held-out subset every 50 steps.
for epoch in range(EPOCH):
    for step, (x, y) in enumerate(train_loader):
        # Reshape each batch to (batch, TIME_STEP, INPUT_SIZE) for the LSTM.
        # Plain tensors are used directly; the `Variable` wrapper is deprecated.
        b_x = x.view(-1, TIME_STEP, INPUT_SIZE)
        b_y = y

        output = rnn(b_x)
        loss = loss_func(output, b_y)
        optimizer.zero_grad()   # clear gradients left over from the previous step
        loss.backward()
        optimizer.step()

        if step % 50 == 0:
            # Evaluation only: disabling autograd avoids building a graph for
            # the whole test subset on every report.
            with torch.no_grad():
                test_output = rnn(test_x)
            # Predicted class = index of the largest logit (softmax would not
            # change the argmax, so it is not needed here).
            pred_y = test_output.argmax(dim=1).numpy()
            # Vectorised mean of the match mask instead of the Python builtin sum.
            accuracy = (pred_y == test_y).mean()
            print("Epoch: ", epoch, "| train loss: %.4f" % loss.item(), "| test accuracy: %.4f" % accuracy)

Epoch:  0 | train loss: 2.2705 | test accuracy: 0.1695
Epoch:  0 | train loss: 1.5417 | test accuracy: 0.5470
Epoch:  0 | train loss: 1.0287 | test accuracy: 0.6255
Epoch:  0 | train loss: 0.7693 | test accuracy: 0.7240
Epoch:  0 | train loss: 0.3347 | test accuracy: 0.7730
Epoch:  0 | train loss: 0.3104 | test accuracy: 0.8495
Epoch:  0 | train loss: 0.4609 | test accuracy: 0.8755
Epoch:  0 | train loss: 0.2363 | test accuracy: 0.8675
Epoch:  0 | train loss: 0.2929 | test accuracy: 0.9115
Epoch:  0 | train loss: 0.1960 | test accuracy: 0.9155
Epoch:  0 | train loss: 0.2718 | test accuracy: 0.8915
Epoch:  0 | train loss: 0.3640 | test accuracy: 0.9300
Epoch:  0 | train loss: 0.1673 | test accuracy: 0.9295
Epoch:  0 | train loss: 0.1372 | test accuracy: 0.9430
Epoch:  0 | train loss: 0.2585 | test accuracy: 0.9415
Epoch:  0 | train loss: 0.2424 | test accuracy: 0.9430
Epoch:  0 | train loss: 0.1220 | test accuracy: 0.9445
Epoch:  0 | train loss: 0.0456 | test accuracy: 0.9360
Epoch:  0 

In [22]:
# Spot-check the trained model on the first 10 samples of the test subset.
with torch.no_grad():  # inference only, so no autograd graph is needed
    test_output = rnn(test_x[:10].view(-1, TIME_STEP, INPUT_SIZE))
# Predicted digit = index of the largest logit for each sample.
pred_y = test_output.argmax(dim=1).numpy()
print(pred_y, ' prediction number')
print(test_y[:10], ' real number')

[7 2 1 0 4 1 4 9 8 9]  prediction number
[7 2 1 0 4 1 4 9 5 9]  real number
