In [51]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt

In [52]:
# Hyper Parameters
EPOCH = 1
BATCH_SIZE = 64
TIME_STEP=28
INPUT_SIZE=28
LR = 0.01
DOWNLOAD_MNIST = False

In [53]:
train_data = torchvision.datasets.MNIST(
    root = './mnist',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNIST
)

In [54]:
type(train_data)

torchvision.datasets.mnist.MNIST

In [55]:
train_data.data.shape

torch.Size([60000, 28, 28])

In [56]:
train_data.targets.shape

torch.Size([60000])

In [57]:
train_loader = Data.DataLoader(dataset=train_data, 
                               batch_size=BATCH_SIZE, 
                               shuffle=True,
                               num_workers=2
                              )

In [58]:
test_data = torchvision.datasets.MNIST(
    root = './mnist',
    train=False,
    transform=torchvision.transforms.ToTensor()
)

In [59]:
test_x = Variable(test_data.data).type(torch.FloatTensor)[:2000]/255.
test_y = test_data.targets.numpy().squeeze()[:2000]

In [80]:
class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()
        self.rnn = nn.LSTM(
            input_size=INPUT_SIZE,
            hidden_size=32,
            num_layers=2,
            batch_first=True, #数据的第一个维度是否是batch
        )
        self.out = nn.Linear(32, 10)
        
    def forward(self, x):
        print("x: " ,x.shape)
        r_out, (h_n, h_c) = self.rnn(x, None)
        print("r_out: ", r_out.shape)
        print("h_n: ", h_n.shape)
        print("h_c: ", h_c.shape)
        out = self.out(r_out[:,-1,:]) # (batch, time_step, input)
        print("out: ", out.shape)
        return out

In [81]:
rnn = RNN()

In [82]:
rnn

RNN(
  (rnn): LSTM(28, 32, num_layers=2, batch_first=True)
  (out): Linear(in_features=32, out_features=10, bias=True)
)

In [83]:
optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()

In [84]:
for epoch in range(EPOCH):
    for step, (x, y) in enumerate(train_loader):
        b_x = Variable(x.view(-1, 28,28))
        b_y = Variable(y)
        
        output = rnn(b_x)
        loss = loss_func(output, b_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if step % 50 == 0:
            test_output = rnn(test_x)
#             pred_y = torch.max(nn.functional.softmax(test_output), 1)[1].data.numpy().squeeze()
            pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
            accuracy = sum(pred_y == test_y) / test_y.size
            print("Epoch: ", epoch, "| train loss: %.4f" % loss.item(), "| test accuracy: %.4f" % accuracy)

x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([2000, 28, 28])
r_out:  torch.Size([2000, 28, 32])
h_n:  torch.Size([2, 2000, 32])
h_c:  torch.Size([2, 2000, 32])
out:  torch.Size([2000, 10])
Epoch:  0 | train loss: 2.3009 | test accuracy: 0.0980
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torc

x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Si

r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.S

r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.S

r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.S

x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Si

x:  torch.Size([2000, 28, 28])
r_out:  torch.Size([2000, 28, 32])
h_n:  torch.Size([2, 2000, 32])
h_c:  torch.Size([2, 2000, 32])
out:  torch.Size([2000, 10])
Epoch:  0 | train loss: 0.4477 | test accuracy: 0.9065
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torc

r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.S

x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Si

r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.S

r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.S

x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Si

x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Si

x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Si

x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Si

x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Size([2, 64, 32])
out:  torch.Size([64, 10])
x:  torch.Size([64, 28, 28])
r_out:  torch.Size([64, 28, 32])
h_n:  torch.Size([2, 64, 32])
h_c:  torch.Si

In [22]:
test_output = rnn(test_x[:10].view(-1,28,28))
pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
print(pred_y, ' prediction number')
print(test_y[:10], ' real number')

[7 2 1 0 4 1 4 9 8 9]  prediction number
[7 2 1 0 4 1 4 9 5 9]  real number
