In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [13]:
# Hyperparameter definitions
SEQUENCE_LENGTH = 28  # time steps per sample; images are reshaped to (-1, SEQUENCE_LENGTH, INPUT_SIZE)
INPUT_SIZE = 28       # features per time step (passed to the RNN as input_size)
HIDDEN_SIZE = 128     # hidden state size of the LSTM
NUM_LAYERS = 2        # number of stacked LSTM layers
NUM_CLASSES = 10      # number of output logits of the final linear layer
BATCH_SIZE = 100      # batch size used by both data loaders
NUM_EPOCHS = 1        # number of passes over the training loader
LR = 0.01             # learning rate given to the Adam optimizer

In [4]:
# MNIST train/test splits, stored under ../.data/ (downloaded when missing) and
# converted to tensors by ToTensor.
train_dataset = torchvision.datasets.MNIST(
    root='../.data/',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
test_dataset = torchvision.datasets.MNIST(
    root='../.data/',
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)

# Only the training loader shuffles; the test loader iterates in dataset order.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../.data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../.data/MNIST/raw/train-images-idx3-ubyte.gz to ../.data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../.data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../.data/MNIST/raw/train-labels-idx1-ubyte.gz to ../.data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../.data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../.data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../.data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../.data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../.data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../.data/MNIST/raw
Processing...
Done!


In [19]:
class RNN(nn.Module):
    """Sequence classifier: a stacked LSTM followed by a linear layer.

    The linear layer is applied only to the LSTM output of the last time step.

    Args:
        input_size: Number of features per time step.
        hidden_size: Size of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of logits produced by the final linear layer.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # batch_first=True: inputs/outputs are laid out as (batch, seq, feature).
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x, step=None):
        """Compute class logits for a batch of sequences.

        Args:
            x: Input tensor of shape (batch, seq_len, input_size).
            step: Iteration index supplied by the caller. When it equals 0 the
                intermediate tensor shapes are printed for inspection; any
                other value (including the default ``None``) is silent.

        Returns:
            Tensor of shape (batch, num_classes).
        """
        debug = step == 0
        if debug:
            print("x.shape [{}]".format(x.shape))
            print("x.size(0) [{}]".format(x.size(0)))

        # Allocate the zero initial states on the input's own device and dtype
        # so the module does not depend on a module-level `device` variable and
        # the states always match the tensors the LSTM is called with.
        state_shape = (self.num_layers, x.size(0), self.hidden_size)
        h0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)
        c0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)

        out, _ = self.lstm(x, (h0, c0))
        if debug:
            print("h0.shape | LSTM [{}]".format(h0.shape))
            print("c0.shape | LSTM [{}]".format(c0.shape))
            print("out.shape | LSTM [{}]".format(out.shape))
            print("out[:, -1, :] | LSTM [{}]".format(out[:, -1, :].shape))

        # Classify from the output of the last time step only.
        out = self.fc(out[:, -1, :])
        if debug:
            print("out.shape | FC [{}]".format(out.shape))

        return out

In [25]:
# Build the model and move its parameters to `device`. The input batches are
# moved to `device` before each forward pass, so the weights must live there
# too; otherwise a CUDA run fails with a device-mismatch error.
model = RNN(INPUT_SIZE, HIDDEN_SIZE, NUM_LAYERS, NUM_CLASSES).to(device)
print(model)

# Cross-entropy over the class logits, optimized with Adam. The optimizer is
# created after the move so it tracks the on-device parameters.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

# Training: NUM_EPOCHS passes over train_loader. Every batch of images is
# reshaped to (batch, SEQUENCE_LENGTH, INPUT_SIZE) before being fed to the model.
total_step = len(train_loader)
for epoch in range(NUM_EPOCHS):
    for i, (images, labels) in enumerate(train_loader):
        labels = labels.to(device)
        images = images.reshape(-1, SEQUENCE_LENGTH, INPUT_SIZE).to(device)

        # Clear gradients from the previous step, then forward / backward / update.
        optimizer.zero_grad()
        outputs = model(images, i)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Shape report for the first batch of the epoch only.
        if i == 0:
            print("input data shape [{}]".format(images.shape))
            print("label data shape [{}]".format(labels.shape))
            print("output data shape [{}]".format(outputs.shape))

        # Progress line every 100 batches.
        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                epoch + 1, NUM_EPOCHS, i + 1, total_step, loss.item()))

            
# Evaluation: switch to inference mode and disable gradient tracking, then
# count how many test samples are classified correctly.
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for i, (images, labels) in enumerate(test_loader):
        images = images.reshape(-1, SEQUENCE_LENGTH, INPUT_SIZE).to(device)
        labels = labels.to(device)
        outputs = model(images, i)

        # Predicted class is the index of the largest logit. Under no_grad the
        # legacy `.data` access is unnecessary, and avoiding an assignment to
        # `_` keeps IPython's last-output variable intact.
        predicted = outputs.argmax(dim=1)

        if i == 0:
            print("predicted.shape [{}]".format(predicted.shape))
            print("labels.shape [{}]".format(labels.shape))

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

RNN(
  (lstm): LSTM(28, 128, num_layers=2, batch_first=True)
  (fc): Linear(in_features=128, out_features=10, bias=True)
)
x.shape [torch.Size([100, 28, 28])]
x.size(0) [100]
h0.shape | LSTM [torch.Size([2, 100, 128])]
c0.shape | LSTM [torch.Size([2, 100, 128])]
out.shape | LSTM [torch.Size([100, 28, 128])]
out[:, -1, :] | LSTM [torch.Size([100, 128])]
out.shape | FC [torch.Size([100, 10])]
input data shape [torch.Size([100, 28, 28])]
label data shape [torch.Size([100])]
output data shape [torch.Size([100, 10])]
Epoch [1/1], Step [100/600], Loss: 0.5629
Epoch [1/1], Step [200/600], Loss: 0.1366
Epoch [1/1], Step [300/600], Loss: 0.2646
Epoch [1/1], Step [400/600], Loss: 0.1495
Epoch [1/1], Step [500/600], Loss: 0.1181
Epoch [1/1], Step [600/600], Loss: 0.0189
x.shape [torch.Size([100, 28, 28])]
x.size(0) [100]
h0.shape | LSTM [torch.Size([2, 100, 128])]
c0.shape | LSTM [torch.Size([2, 100, 128])]
out.shape | LSTM [torch.Size([100, 28, 128])]
out[:, -1, :] | LSTM [torch.Size([100, 128])