In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from torch.autograd import Variable
from torch import Tensor
from torch.utils.data import DataLoader

import torchvision.transforms as transforms
import torchvision.datasets as dataset

import math

In [6]:
# Detect the accelerator once; every other setting below is derived from this flag.
cuda = torch.cuda.is_available()
DEVICE = torch.device('cuda') if cuda else torch.device('cpu')
# Legacy tensor-type alias. Note that it rebinds the name `Tensor` imported
# from torch in the imports cell.
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# Seed the CPU generator and, when CUDA is present, all CUDA devices as well.
torch.manual_seed(29)
if cuda:
    torch.cuda.manual_seed_all(29)

In [7]:
# Preprocessing pipeline for every split: ToTensor() turns the PIL image into a
# float tensor in [0, 1]; Normalize then subtracts 0.5 and divides by 1.0
# (i.e. a pure shift), so pixel values end up in [-0.5, 0.5].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(1.0,)),
    ]
)

In [8]:
download_root = './data'

# All three splits share the same root directory and preprocessing; only the
# `train` flag differs. download=True fetches the files on first use.
#
# NOTE(review): `val_dataset` and `test_dataset` are both built from the
# train=False split, so validation and test metrics are computed on the very
# same images. If the test set must stay held out, carve the validation set
# out of the training split instead (e.g. torch.utils.data.random_split).
train_dataset = dataset.MNIST(
    root=download_root, train=True, download=True, transform=transform,
)
val_dataset = dataset.MNIST(
    root=download_root, train=False, download=True, transform=transform,
)
test_dataset = dataset.MNIST(
    root=download_root, train=False, download=True, transform=transform,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 117410023.65it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 65727451.88it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 29437342.07it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 14487094.12it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [9]:
batch_size = 64

# One loader per split, all with the same batch size. Only the training split
# is reshuffled every epoch; val/test are iterated in a fixed order.
dataloaders = {
    'train': DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
    'val': DataLoader(val_dataset, batch_size=batch_size, shuffle=False),
    'test': DataLoader(test_dataset, batch_size=batch_size, shuffle=False),
}

In [10]:
# NOTE(review): this `batch_size` (100) does NOT affect the DataLoaders, which
# were already built with the earlier batch size; it is kept only in case later
# cells read it.
batch_size = 100
n_iters = 6000  # total number of optimisation steps we want to run

# One optimisation step per batch, so the steps per epoch equal the number of
# batches the training loader yields (len(DataLoader) is the batch count).
# The previous formula divided this batch count by `batch_size` a second time,
# which inflated num_epochs far beyond n_iters worth of steps.
iters_per_epoch = len(dataloaders['train'])
num_epochs = int(n_iters / iters_per_epoch)

In [12]:
class LSTM_cell(nn.Module):
    """A single LSTM cell that advances the hidden/cell state by one time step.

    Args:
        input_size: number of features in each per-step input vector.
        hidden_size: number of units in the hidden and cell states.
        bias: whether the two affine maps (``x2h`` and ``h2h``) use bias terms.
    """

    def __init__(self, input_size, hidden_size, bias=True):
        super(LSTM_cell, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        # Each map produces the pre-activations of all four gates at once.
        # forward() splits the 4 * hidden_size outputs with chunk(4, 1) in the
        # order: input gate, forget gate, cell candidate, output gate.
        self.x2h = nn.Linear(input_size, 4 * hidden_size, bias=bias)
        self.h2h = nn.Linear(hidden_size, 4 * hidden_size, bias=bias)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every weight and bias uniformly in [-1/sqrt(H), 1/sqrt(H)]."""
        std = 1.0 / math.sqrt(self.hidden_size)
        for w in self.parameters():
            # `uniform_` is the in-place sampler (a bare `uniform` method does
            # not exist and used to raise AttributeError at construction).
            w.data.uniform_(-std, std)

    def forward(self, x, hidden):
        """Compute one LSTM step.

        Args:
            x: input for this step; flattened to (-1, x.size(1)), so it is
                expected to be (batch, input_size).
            hidden: tuple ``(hx, cx)`` of the previous hidden state and cell
                state, each (batch, hidden_size). ``hx`` may also carry an extra
                leading dimension of size 1; it is folded away below.

        Returns:
            Tuple ``(hy, cy)``: the new hidden state and new cell state, each
            (batch, hidden_size).
        """
        hx, cx = hidden
        x = x.view(-1, x.size(1))

        # Input contribution + recurrent contribution for all four gates.
        # Reshape (instead of squeeze) to (batch, 4 * hidden_size): squeeze
        # would also drop the batch dimension when batch == 1, and chunk(4, 1)
        # would then fail or split the wrong axis.
        gates = self.x2h(x) + self.h2h(hx)
        gates = gates.reshape(-1, 4 * self.hidden_size)
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        # Keep part of the old cell state, add the gated candidate, then expose
        # a gated, squashed view of the new cell state as the hidden state.
        cy = cx * forgetgate + ingate * cellgate
        hy = outgate * torch.tanh(cy)

        return (hy, cy)

In [13]:
class LSTM(nn.Module):
    """Sequence classifier that unrolls one ``LSTM_cell`` over time and maps the
    final hidden state to ``output_dim`` outputs with a linear layer.

    Args:
        input_dim: number of features per time step.
        hidden_dim: number of hidden units in the cell.
        layer_dim: stored on the module; only a single cell layer is unrolled.
        output_dim: size of the output of the final linear layer.
        bias: whether the cell's affine maps use bias terms.
    """

    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim, bias=True):
        super(LSTM, self).__init__()

        self.hidden_dim = hidden_dim  # number of hidden units in the cell
        self.layer_dim = layer_dim
        # Forward `bias` explicitly; previously `layer_dim` was passed
        # positionally into the cell's `bias` slot and `bias` was ignored.
        self.lstm = LSTM_cell(input_dim, hidden_dim, bias)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Run the cell over every time step and classify the last hidden state.

        Args:
            x: input indexed as ``x[:, t, :]`` per step; presumably
                (batch, seq_len, input_dim) — confirm against the caller.

        Returns:
            Tensor of shape (batch, output_dim).
        """
        batch = x.size(0)
        # Zero initial states created on the input's device and dtype, so the
        # model works wherever it and its input live (not merely whenever CUDA
        # happens to be available). Both states are (batch, hidden_dim).
        hn = x.new_zeros(batch, self.hidden_dim)
        cn = x.new_zeros(batch, self.hidden_dim)

        for seq in range(x.size(1)):  # unroll over the time dimension
            hn, cn = self.lstm(x[:, seq, :], (hn, cn))

        # hn is already (batch, hidden_dim); not squeezing keeps the batch
        # dimension when batch == 1.
        return self.fc(hn)