# Rede Neural Recorrente simples para classificar dígitos MNIST

Este experimento ilustra o uso de uma rede neural recorrente na tarefa de classificação
de dígitos manuscritos do dataset MNIST.

A imagem é estruturada em 28 sequências de 28 pixels cada. Cada elemento da sequência é
formado por uma linha da imagem. Cada linha da imagem contém 28 atributos (pixels).

<img src='../figures/RNN_MNIST.png', width= 700></img>

**Obs:** Este experimento foi inspirado no artigo:
"A Simple Way to Initialize Recurrent Networks of Rectified Linear Units"
by Quoc V. Le, Navdeep Jaitly, Geoffrey E. Hinton
arxiv:1504.00941v2 [cs.NE] 7 Apr 2015
http://arxiv.org/pdf/1504.00941v2.pdf

A principal modificação é a formatação da imagem em 28 sequências de 28 pixels.

## Importação

In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import os

import torch
import torch.nn as nn
from   torch.autograd import Variable

import torchvision

import lib.pytorch_trainer as ptt
use_gpu = torch.cuda.is_available()
print('GPU available:', use_gpu)

GPU available: True


In [2]:
print(''.join(os.popen('free -th -m').readlines()))


              total        used        free      shared  buff/cache   available
Mem:            15G        3.4G        5.3G        181M        6.9G         11G
Swap:           15G        912M         14G
Total:          31G        4.3G         20G



## Leitura do dataset

In [3]:
dataset_dir = '/data/datasets/MNIST/'

data_train, target_train = torch.load(dataset_dir + 'processed/training.pt')
data_test,  target_test  = torch.load(dataset_dir + 'processed/test.pt')

data_train = data_train.float() / 255.
data_test  = data_test.float() / 255.

In [4]:
print(''.join(os.popen('free -th -m').readlines()))


              total        used        free      shared  buff/cache   available
Mem:            15G        3.6G        5.1G        181M        6.9G         11G
Swap:           15G        912M         14G
Total:          31G        4.5G         20G



## Pouquíssimas amostras - depurando apenas

In [5]:
if True:
    n_samples_train = 1000
    n_samples_test  = 500
else:
    n_samples_train = data_train.size(0)
    n_samples_test  = data_test.size(0)
    
x_train = data_train[:n_samples_train].clone()
y_train = target_train[:n_samples_train].clone()
x_test  = data_test[:n_samples_test].clone()
y_test  = target_test[:n_samples_test].clone()
del data_train, target_train, data_test, target_test

In [6]:
x_train.shape

torch.Size([1000, 28, 28])

In [7]:
print(''.join(os.popen('free -th -m').readlines()))

              total        used        free      shared  buff/cache   available
Mem:            15G        3.4G        5.3G        181M        6.9G         11G
Swap:           15G        912M         14G
Total:          31G        4.3G         20G



## SimpleRNN com 100 neurônios

In [8]:
class Model_RNN(nn.Module):
    def __init__(self, hidden_size):
        super(Model_RNN, self).__init__()
        self.rnn = nn.RNN(28, hidden_size, 1, batch_first=True, nonlinearity='relu', dropout=0.05) # 28 atributos, 100 neurônios, 1 camada  
        self.out = nn.Linear(hidden_size, 10)

    def forward(self, xin):
        _,x = self.rnn(xin)
        x  = self.out(x)
        return torch.squeeze(x,dim=0)
    

model_rnn = Model_RNN(100)
if use_gpu:
    model_rnn = model_rnn.cuda()
model_rnn

Model_RNN (
  (rnn): RNN(28, 100, batch_first=True, dropout=0.05)
  (out): Linear (100 -> 10)
)

In [9]:
print(''.join(os.popen('free -th -m').readlines()))

              total        used        free      shared  buff/cache   available
Mem:            15G        4.7G        3.9G        187M        6.9G         10G
Swap:           15G        912M         14G
Total:          31G        5.6G         18G



## Predict com uma amostra

In [None]:
xin = x_train[0:1]
xin = torch.zeros((1,28,28))
print('xin.shape:',xin.shape)
xv_in = Variable(xin)
if use_gpu:
    xv_in = xv_in.cuda()
ypred = model_rnn(xv_in)
print(torch.squeeze(ypred[:,:10]))

## Criando o treinador

In [None]:
chkpt_cb = ptt.ModelCheckpoint('../../models/SimpleRNN_MNIST_t', reset=True, verbose=1)

trainer = ptt.DeepNetTrainer(model_rnn,
                        criterion = nn.CrossEntropyLoss(),
                        optimizer = torch.optim.Adam(model_rnn.parameters(),lr=1e-3),
                        callbacks = [chkpt_cb, ptt.AccuracyMetric(),ptt.PrintCallback()]
                        )

In [None]:
print(''.join(os.popen('free -th -m').readlines()))

In [None]:
trainer.fit(20, x_train, y_train, valid_data=(x_test, y_test),batch_size=32)

In [None]:
print(''.join(os.popen('free -th -m').readlines()))

In [None]:
plt.plot(trainer.metrics['valid']['acc'])

## Usando LSTM

In [None]:
class Model_LSTM(nn.Module):
    def __init__(self, hidden_size):
        super(Model_LSTM, self).__init__()
        self.rnn = nn.LSTM(28, hidden_size, 1, batch_first=True, dropout=0.05) # 28 atributos, 100 neurônios, 1 camada  
        self.out = nn.Linear(hidden_size, 10)

    def forward(self, xin):
        _,(x,_) = self.rnn(xin)
        x  = self.out(x)
        return torch.squeeze(x,dim=0)
    
model_lstm = Model_LSTM(100)
if use_gpu:
    model_lstm = model_lstm.cuda()

In [None]:
print(''.join(os.popen('free -th -m').readlines()))

### Predict com uma amostra

In [None]:
xin = x_train[0:1]
xin = torch.zeros((1,28,28))
print('xin.shape:',xin.shape)
xv_in = Variable(xin)
if use_gpu:
    xv_in = xv_in.cuda()
ypred = model_lstm(xv_in)
print(torch.squeeze(ypred[:,:10]))

### Criando o treinador

In [None]:
chkpt_cb = ptt.ModelCheckpoint('../../models/SimpleRNN_MNIST_lstm', reset=True, verbose=1)

trainer_lstm = ptt.DeepNetTrainer(
                model_lstm,
                criterion = nn.CrossEntropyLoss(),
                optimizer = torch.optim.Adam(model_lstm.parameters(),lr=1e-3),
                callbacks = [chkpt_cb, ptt.AccuracyMetric(),ptt.PrintCallback()]
                )

In [None]:
trainer_lstm.fit(20, x_train, y_train, valid_data=(x_test, y_test),batch_size=32)

In [None]:
print(''.join(os.popen('free -th -m').readlines()))

In [None]:
plt.plot(trainer_lstm.metrics['valid']['acc'])

## Comparação RNN x LSTM

    - Acurácia RNN: 97,5%
    - Acurácia LSTM: 98,9%

# Exercícios

In [None]:
1. 