# 足し算をSeq2Seqで学習

In [0]:
# 辞書の作成
word2id = {str(i): i for i in range(10)}
# <pad>:パディング, <eos>:終了文字
word2id.update({"<pad>": 10, "+": 11, "<eos>": 12})
id2word = {v: k for k, v in word2id.items()}

In [3]:
word2id

{'+': 11,
 '0': 0,
 '1': 1,
 '2': 2,
 '3': 3,
 '4': 4,
 '5': 5,
 '6': 6,
 '7': 7,
 '8': 8,
 '9': 9,
 '<eos>': 12,
 '<pad>': 10}

In [0]:
import random
from sklearn.model_selection import train_test_split

def load_dataset(N=20000):
    def generate_number():
        number = [random.choice(list("0123456789")) for _ in range(random.randint(1, 3))] 
        # a <= N <= b random.randint(a, b)
        return int("".join(number))
    
    def padding(string, training=True):
        string = "{:*<7s}".format(string) if training else "{:*<6s}".format(string)
        return string.replace("*", "<pad>")
    
    # 辞書を使ってidに変換
    def transform(string, seq_len=7):
        tmp = []
        for i, c in enumerate(string):
            try:
                tmp.append(word2id[c])
            except:
                tmp += [word2id["<pad>"]] * (seq_len - i)
                break
        return tmp
    data = []
    target = []    
    for _ in range(N):
        x = generate_number()
        y = generate_number()
        z = x + y
        left = padding(str(x) + "+" + str(y))
        right = padding(str(z), training=False)
        data.append(transform(left))
        right = transform(right, seq_len=6)
        # <eos>スタート
        right = [12] + right[:5]
        right[right.index(10)] = 12
        target.append(right)
        
    return data, target

In [0]:
# 学習データの作成
data, target = load_dataset()
train_x, test_x, train_t, test_t = train_test_split(data, target, test_size=0.1)

In [23]:
train_x[0], train_t[0]

([1, 11, 3, 10, 10, 10, 10], [12, 4, 12, 10, 10, 10])

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


embedding_dim = 16
hidden_dim = 128
vocab_size = len(word2id)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Encoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, batch_size=100):
        super(Encoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size

        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=word2id["<pad>"])
        self.gru = nn.GRU(embedding_dim, hidden_dim, batch_first=True)

    def forward(self, indices):
        embedding = self.word_embeddings(indices)
        if embedding.dim() == 2:
            embedding = torch.unsqueeze(embedding, 1)
        _, state = self.gru(embedding, torch.zeros(1, self.batch_size, self.hidden_dim, device=device))
        
        return state


class Decoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, batch_size=100):
        super(Decoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size

        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=word2id["<pad>"])
        self.gru = nn.GRU(embedding_dim, hidden_dim, batch_first=True)
        self.output = nn.Linear(hidden_dim, vocab_size)

    def forward(self, index, state):
        embedding = self.word_embeddings(index)
        if embedding.dim() == 2:
            embedding = torch.unsqueeze(embedding, 1)
        gruout, state = self.gru(embedding, state)
        output = self.output(gruout)
        return output, state


encoder = Encoder(vocab_size, embedding_dim, hidden_dim).to(device)
decoder = Decoder(vocab_size, embedding_dim, hidden_dim).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=word2id["<pad>"])

# Initialize opotimizers
encoder_optimizer = optim.Adam(encoder.parameters(), lr=0.001)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=0.001)

In [19]:
from datetime import datetime
from sklearn.utils import shuffle

batch_size=100
def train2batch(data, target, batch_size=100):
    input_batch = []
    output_batch = []
    data, target = shuffle(data, target)
    
    for i in range(0, len(data), batch_size):
        input_tmp = []
        output_tmp = []
        for j in range(i, i+batch_size):
            input_tmp.append(data[j])
            output_tmp.append(target[j])
        input_batch.append(input_tmp)
        output_batch.append(output_tmp)
    return input_batch, output_batch

def get_current_time():
    return datetime.now().strftime("%Y-%m-%d %H:%M:%S")

print("Training...")
n_epoch = 100
for epoch in range(1, n_epoch+1):
    input_batch, output_batch = train2batch(train_x, train_t)
    for i in range(len(input_batch)):
        # Zero gradients
        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()
        # Prepare tensor
        inputs = torch.tensor(input_batch[i], device=device)
        outputs = torch.tensor(output_batch[i], device=device)
        # Forward pass through encoder
        encoder_hidden = encoder(inputs)
        # Create source and target
        source = outputs[:, :-1]
        target = outputs[:, 1:]
        decoder_hidden = encoder_hidden
        
        # Forward batch of sequences through decoder one time step at a time
        loss = 0
        for i in range(source.size(1)):
            decoder_output, decoder_hidden = decoder(source[:, i], decoder_hidden)
            decoder_output = torch.squeeze(decoder_output)
            loss += criterion(decoder_output, target[:, i])

        # Perform backpropagation
        loss.backward()
        
        # Adjust model weights
        encoder_optimizer.step()
        decoder_optimizer.step()
    
    if epoch % 10 == 0:
        print(get_current_time(), "Epoch %d: %.2f" % (epoch, loss.item()))        
        
    if epoch % 10 == 0:
        model_name = "seq2seq_calculator_v{}.pt".format(epoch)
        torch.save({
            'encoder_model': encoder.state_dict(),
            'decoder_model': decoder.state_dict(),
        }, model_name)
        print("Saving the checkpoint...")

Training...
2020-02-20 14:51:03 Epoch 10: 3.06
Saving the checkpoint...
2020-02-20 14:51:54 Epoch 20: 1.96
Saving the checkpoint...
2020-02-20 14:52:43 Epoch 30: 1.36
Saving the checkpoint...
2020-02-20 14:53:33 Epoch 40: 1.01
Saving the checkpoint...
2020-02-20 14:54:23 Epoch 50: 0.87
Saving the checkpoint...
2020-02-20 14:55:13 Epoch 60: 0.31
Saving the checkpoint...
2020-02-20 14:56:03 Epoch 70: 0.26
Saving the checkpoint...
2020-02-20 14:56:53 Epoch 80: 0.09
Saving the checkpoint...
2020-02-20 14:57:43 Epoch 90: 0.06
Saving the checkpoint...
2020-02-20 14:58:34 Epoch 100: 0.04
Saving the checkpoint...
