In [None]:
# mnist 다운받을 때 필요
!pip install wget

# master 브랜치를 가져옴
!git clone https://github.com/you-just-want-attention/all-about-mnist.git

In [1]:
import sys
#sys.path.append("./all-about-mnist/utils/")
sys.path.append("../../utils/")
from dataset import CalculationDataset

import numpy as np
import torch
import torch.nn as nn
import pandas as pd
from tqdm import tqdm

dtype = torch.FloatTensor

## 1) Hyper Parameter Setting

In [9]:
# S: Symbol that shows starting of decoding input
# E: Symbol that shows starting of decoding output
# P: Symbol that will fill in blank sequence if current batch data size is short than time steps
char_arr = [c for c in 'SEP1234567890(){}[]+-*/'] # embedding table
char_dict = {i:char for i,char in enumerate(char_arr)}

# 모델에 대한 Setting
N_CLASS = len(char_arr)
N_HIDDEN = 128
BATCH_SIZE = 32

# 데이터셋에 대한 Setting
num_digit = 3
n_step = int(num_digit + (num_digit - 1) + (num_digit//2)*2 + 2)

## 2) 데이터 셋 제너레이터 만들기

In [3]:
class BatchGenerator:
    def __init__(self, dataset, char_arr, batch_size=32, n_step=30, shuffle=True):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.n_step = n_step
        self.char_arr = char_arr
        self.num_dic = {n: i for i, n in enumerate(char_arr)}
        self.n_class = len(self.num_dic)
        
        self.on_epoch_end()
        
    def __len__(self):
        'Denotes the number of batches per epoch'
        return len(self.dataset) // self.batch_size
    
    def __getitem__(self, index):
        _, eq_results, _, equations = self.dataset[self.batch_size * index:
                                                   self.batch_size * (index+1)]
        eq_results = eq_results.astype(int).astype(str)
        seq_data = np.stack([equations,eq_results],axis=-1)
        input_batch, output_batch, target_batch = self.make_batch(seq_data)
        return input_batch, output_batch, target_batch
    
    def make_batch(self, seq_data):
        input_batch, output_batch, target_batch = [], [], []

        for idx, (equation, result) in enumerate(seq_data):
            # input은 패딩, target은 한 후 padding 지점 앞 marking
            if self.n_step < len(equation):
                raise ValueError("n_Step이 너무 작습니다. 더 큰값으로 설정해주세요")
            
            equation = equation + "P" * (self.n_step-len(equation))
            result = result + "E" + "P" * (self.n_step-len(result)-1)

            input_data = [self.num_dic[n] for n in equation]
            output_data = [self.num_dic[n] for n in ('S' + result[:-1])]
            target_data = [self.num_dic[n] for n in result]

            input_batch.append(np.eye(self.n_class)[input_data])
            output_batch.append(np.eye(self.n_class)[output_data])
            # one-hot으로 들어가는 것 아님
            target_batch.append(target_data)

        # make tensor
        return torch.Tensor(input_batch), torch.Tensor(output_batch), torch.LongTensor(target_batch)     
        
    def on_epoch_end(self):
        'Updates indexes after each epoch'
        if self.shuffle:
            self.dataset.shuffle()


In [4]:
# 학습 데이터셋에 대한 Generator
train_set = CalculationDataset('train', digit=num_digit,)
traingen = BatchGenerator(train_set, char_arr, BATCH_SIZE, n_step=n_step)

# 테스트 데이터셋에 대한 Generator
test_set = CalculationDataset('train', digit=num_digit,)
testgen = BatchGenerator(test_set, char_arr, BATCH_SIZE, n_step=n_step)

In [5]:
# 데이터셋이 올바른지 확인해보기
input_batch, output_batch, target_batch = traingen[0]

print("data size : {}".format(len(input_batch)))
print("padded sentence size : {}".format(len(input_batch[0])))
print("vocabulary size : {}".format(len(input_batch[0][0])))

data size : 32
padded sentence size : 9
vocabulary size : 23


## 3) Decoder 구현

In [6]:
def decode_data(data):
    """
    Encoding 된 data를 다시 문자로 바꾸어주는 역할을 함
    """
    if isinstance(data,torch.Tensor):
        data = data.detach().numpy()
    if np.ndim(data) == 2:
        data = np.expand_dims(data,axis=0)
    return (
    pd.DataFrame(data.argmax(axis=-1))
       .applymap(lambda x : char_dict[x])
       .apply(lambda x : x.sum(),axis=1)
       .apply(lambda x : x.replace("S",""))
       .apply(lambda x : x.replace("E",""))        
       .apply(lambda x : x.replace("P",""))                
       .values
       .tolist()
    )

In [10]:
decode_data(output_batch)

['6',
 '0',
 '1',
 '5',
 '5',
 '1',
 '0',
 '74',
 '0',
 '2',
 '7',
 '9',
 '16',
 '0',
 '294',
 '-27',
 '-2',
 '9',
 '3',
 '-1',
 '0',
 '-1',
 '0',
 '0',
 '1',
 '18',
 '36',
 '0',
 '12',
 '10',
 '-3',
 '-1']

In [12]:
decode_data(input_batch)

['8-7/(4)',
 '1/(9*(6))',
 '8-(7)-(0)',
 '0*5+(5)',
 '7-7+(5)',
 '4+((4-7))',
 '(8)/7/2',
 '(8)*9+2',
 '0*9/(8)',
 '4*(3/6)',
 '4+(0)+(3)',
 '(1)*1*(9)',
 '9-(0)+7',
 '7*9*(0)',
 '7*6*(7)',
 '1-4*(7)',
 '(6/(4-7))',
 '(9+(1)/4)',
 '(3)-0/4',
 '(6-(7))-0',
 '3/(9-3)',
 '1-((1))-1',
 '9-9*(1)',
 '0+1/((7))',
 '1-(6)+6',
 '7*2+(4)',
 '7*(4)+8',
 '(0)+(7)/8',
 '6*2-(0)',
 '(1)+4+(5)',
 '5-(1)*(8)',
 '7-2-(6)']

## 4) model 구성하기

In [13]:
class Seq2Seq(nn.Module):
    def __init__(self):
        super(Seq2Seq, self).__init__()
        # PyTorch RNN class implements the Elman(vanilla) RNN
        self.enc_cell = nn.RNN(input_size=N_CLASS, 
                               hidden_size=N_HIDDEN,
                               dropout=0.5)
        self.dec_cell = nn.RNN(input_size=N_CLASS, 
                               hidden_size=N_HIDDEN,
                               dropout=0.5)
        self.fc = nn.Linear(N_HIDDEN, N_CLASS)

    def forward(self, enc_input, enc_hidden, dec_input):
        enc_input = enc_input.transpose(0, 1)
        dec_input = dec_input.transpose(0, 1)

        _, enc_states = self.enc_cell(enc_input,
                                      enc_hidden)
        outputs, _ = self.dec_cell(dec_input,
                                   enc_states)

        model = self.fc(outputs)
        return model

In [14]:
encoding_cell = nn.RNN(input_size=N_CLASS,
                       hidden_size=N_HIDDEN,
                       dropout=0.5)

  "num_layers={}".format(dropout, num_layers))


## 4) 모델 학습하기

In [None]:
model = Seq2Seq()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

num_epoch = 100 # Epoch 수
for epoch in range(num_epoch):
    for step in tqdm(range(len(traingen))):
        input_batch, output_batch, target_batch = traingen[step]
        
        # hidden value 초기화
        hidden = torch.zeros(1, BATCH_SIZE, N_HIDDEN)
        # 이전 interation에서 축적 됐을 수도 있으니 값을 초기화.
        optimizer.zero_grad()
        # 학습 데이터 모델에 전달(input_batch, hidden, output_batch)
        output = model(input_batch, hidden, output_batch)
        output = output.transpose(0, 1)
        loss = 0
        # 결과와 정답 .... and element 각각 비교
        for i in range(0, len(target_batch)):
            loss += criterion(output[i], target_batch[i])
        # propagages the loss value back through the network. 
        loss.backward()
        # 미분한거 업데이트 하는 function.
        optimizer.step()
    print("--------------------------")
    print('Epoch : {:2d} | cost {:.3f}'.format(epoch+1,loss))
    print("--------------------------")
    # Sample Dataset 평가
    print("Test------------")
    input_batch, output_batch, target_batch = testgen[0]

    hidden = torch.zeros(1, len(input_batch), N_HIDDEN) 
    pred_batch = model(input_batch, hidden, output_batch)

    decoded_input_batch = decode_data(input_batch)
    decoded_output_batch = decode_data(output_batch)
    decoded_pred_batch = decode_data(pred_batch)

    for idx, (input_str, answer, prediction) in enumerate(
        zip(decoded_input_batch, decoded_output_batch,decoded_pred_batch)):
        print("{}th test case : ".format(idx))
        print(input_str,"->")
        print("answer : ",answer)
        print("prediction: ",prediction)

        print("\n")

100%|██████████| 572/572 [00:33<00:00, 17.28it/s]
  0%|          | 0/572 [00:00<?, ?it/s]

--------------------------
Epoch :  1 | cost 17.472
--------------------------
Test------------
0th test case : 
1+((2/4)) ->
answer :  1
prediction:  11111111111111111111111111111111


1th test case : 
(4*9)-3 ->
answer :  33
prediction:  


2th test case : 
7-9+((6)) ->
answer :  4
prediction:  


3th test case : 
7+(4*(9)) ->
answer :  43
prediction:  


4th test case : 
9+4*(1) ->
answer :  13
prediction:  


5th test case : 
9/5+(3) ->
answer :  4
prediction:  


6th test case : 
1*(2)+(7) ->
answer :  9
prediction:  


7th test case : 
1+0-(9) ->
answer :  -8
prediction:  


8th test case : 
(6+(8)-0) ->
answer :  14
prediction:  




 67%|██████▋   | 382/572 [00:22<00:11, 16.69it/s]