In [1]:
from __future__ import print_function
import os
import random
import numpy as np
import torch
import torch.nn as nn
import itertools

# Record which PyTorch version produced the outputs in this notebook.
print(torch.__version__)

1.10.2


In [2]:
# Token vocabulary: the two special markers, the ten digits, then the two operators.
# A token's position in this list is its integer id.
all_letters = ['<START>', '<STOP>'] + [str(digit) for digit in range(10)] + ['+', '=']
n_letters = len(all_letters)  # vocabulary size

In [3]:
# Reverse lookup: token string -> integer id (its index in all_letters).
char_to_int = {token: idx for idx, token in enumerate(all_letters)}

In [4]:
# Integer-id tensor of the sequence, prefixed with <START>, used as model input.
def input_tensor(line):
    """Encode ``line`` as a (1, len(line) + 1) int32 tensor of token ids.

    The first position holds the ``<START>`` id; position i + 1 holds the id of
    ``line[i]``. Raises KeyError if a character is not in ``char_to_int``.
    """
    token_ids = [char_to_int["<START>"]] + [char_to_int[ch] for ch in line]
    return torch.tensor([token_ids], dtype=torch.int32)

# LongTensor of every character followed by <STOP>, used as the training target.
def target_tensor(line):
    """Encode ``line`` as a (1, len(line) + 1) int64 tensor of token ids.

    Position i holds the id of ``line[i]``; the final position holds the
    ``<STOP>`` id. Aligned with ``input_tensor(line)``, which is shifted right by
    one position because of its leading ``<START>``.
    """
    ids = [char_to_int[ch] for ch in line]
    ids += [char_to_int["<STOP>"]]
    return torch.LongTensor([ids])

In [5]:
def random_training_sample():
    """Draw a random addition problem ``a+b=`` with 0 <= a, b <= 999.

    Returns:
        (encoder_input, decoder_input, decoder_target):
        ``input_tensor("a+b=")``, ``input_tensor(str(a+b))`` and
        ``target_tensor(str(a+b))``.
    """
    lhs, rhs = random.randint(0, 999), random.randint(0, 999)
    question = "%d+%d=" % (lhs, rhs)
    answer = str(lhs + rhs)
    return input_tensor(question), input_tensor(answer), target_tensor(answer)

In [6]:
# Peek at one sample: (encoder input, decoder input, decoder target) token-id tensors.
random_training_sample()

(tensor([[ 0,  4,  4, 11, 12, 10,  4, 10, 13]], dtype=torch.int32),
 tensor([[0, 3, 2, 7, 9]], dtype=torch.int32),
 tensor([[3, 2, 7, 9, 1]]))

In [7]:
class Encoder(torch.nn.Module):
    """GRU encoder that compresses a token-id sequence into its final hidden state.

    Hidden states are exchanged with callers in batch-first layout
    ``(batch, num_layers, hidden_dim)``. They are transposed to the
    ``(num_layers, batch, hidden_dim)`` layout expected by the sequence-first
    ``torch.nn.GRU`` only inside ``forward``.
    """

    @property
    def device(self):
        """Device of the module's parameters (read from the first parameter)."""
        return next(self.parameters()).device

    def __init__(self,
                 vocab_dim,
                 emb_dim = 10,
                 hidden_dim = 10,
                 num_layers = 3):
        """
        Args:
            vocab_dim: number of distinct token ids (rows of the embedding table).
            emb_dim: size of each token embedding.
            hidden_dim: size of the GRU hidden state.
            num_layers: number of stacked GRU layers.
        """
        super(Encoder, self).__init__()

        self.emb_dim = emb_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # trainable embedding matrix
        self.embedding = torch.nn.Embedding(vocab_dim, emb_dim)

        # Sequence-first GRU (batch_first defaults to False), hence the
        # transposes in forward.
        self.encoder = torch.nn.GRU(
            emb_dim, hidden_dim, num_layers)

    def init_hidden(self, batch_size=1):
        """Return an all-zero hidden state of shape (batch_size, num_layers, hidden_dim)."""
        return torch.zeros((batch_size, self.num_layers, self.hidden_dim),
                           device=self.device)

    def forward(self, input, h):
        """Encode ``input`` and return the GRU's final hidden state.

        Args:
            input: integer token ids of shape (batch_size, seq_len).
            h: initial hidden state of shape (batch_size, num_layers, hidden_dim),
                e.g. from ``init_hidden``.

        Returns:
            Final hidden state of shape (batch_size, num_layers, hidden_dim).
        """
        input = self.embedding(input)         # (batch_size, seq_len, emb_dim)
        input = torch.transpose(input, 0, 1)  # (seq_len, batch_size, emb_dim)

        # Transposing makes h non-contiguous once batch_size > 1, and GRU kernels
        # may reject a non-contiguous hx. For batch_size == 1 this is a no-op.
        h = torch.transpose(h, 0, 1).contiguous()  # (num_layers, batch_size, hidden_dim)

        _, h = self.encoder(input, h)  # per-step outputs are not needed
        return torch.transpose(h, 0, 1)
        

In [49]:
class Decoder(torch.nn.Module):
    """GRU decoder mapping an initial hidden state to per-step vocabulary logits.

    ``forward`` performs teacher-forced decoding for training; ``generate``
    performs greedy decoding of a single sequence. Hidden states are passed in
    batch-first layout ``(batch, num_layers, hidden_dim)``.
    """

    @property
    def device(self):
        """Device of the module's parameters (read from the first parameter)."""
        return next(self.parameters()).device

    def __init__(self,
                 vocab_dim,
                 output_dim,
                 emb_dim = 10,
                 hidden_dim = 10,
                 num_layers = 1):
        """
        Args:
            vocab_dim: number of distinct input token ids (rows of the embedding table).
            output_dim: number of logits produced per step.
            emb_dim: size of each token embedding.
            hidden_dim: size of the GRU hidden state.
            num_layers: number of stacked GRU layers; must match the layer count of
                whatever hidden state is passed in.
        """
        super(Decoder, self).__init__()

        self.emb_dim = emb_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_layers = num_layers

        self.embedding = torch.nn.Embedding(vocab_dim, self.emb_dim)

        # Sequence-first GRU (batch_first defaults to False).
        self.decoder = torch.nn.GRU(
            emb_dim, hidden_dim, num_layers)

        # Projects each GRU output onto the output vocabulary.
        self.linear = torch.nn.Linear(hidden_dim, output_dim)

    def init_hidden(self, batch_size=1):
        """Return an all-zero hidden state of shape (batch_size, num_layers, hidden_dim)."""
        return torch.zeros((batch_size, self.num_layers, self.hidden_dim),
                           device=self.device)

    def forward(self, real, h):
        """Teacher-forced decoding.

        Args:
            real: decoder input token ids of shape (batch_size, seq_len).
            h: initial hidden state of shape (batch_size, num_layers, hidden_dim).

        Returns:
            Logits of shape (batch_size, seq_len, output_dim).
        """
        input = self.embedding(real)          # (batch_size, seq_len, emb_dim)
        input = torch.transpose(input, 0, 1)  # (seq_len, batch_size, emb_dim)
        # Transposing makes h non-contiguous once batch_size > 1; GRU kernels may
        # reject a non-contiguous hx. No-op for batch_size == 1.
        h = torch.transpose(h, 0, 1).contiguous()
        d, _ = self.decoder(input, h)         # (seq_len, batch_size, hidden_dim)
        answers = self.linear(d)              # (seq_len, batch_size, output_dim)
        return torch.transpose(answers, 0, 1)

    def generate(self, h, max_len=50):
        """Greedily decode one sequence, starting from the ``<START>`` token.

        Uses the notebook-level ``char_to_int`` / ``all_letters`` vocabulary.

        Args:
            h: initial hidden state of shape (1, num_layers, hidden_dim); the batch
                size must be 1 because a single ``<START>`` token is fed.
            max_len: maximum number of tokens to produce.

        Returns:
            The decoded string. Decoding stops before the first ``<STOP>`` token or
            after ``max_len`` tokens.
        """
        stop_id = char_to_int['<STOP>']
        token = torch.tensor([[char_to_int['<START>']]], dtype=torch.long,
                             device=self.device)  # (seq=1, batch=1)
        h = torch.transpose(h, 0, 1).contiguous()

        s = ""
        for _ in range(max_len):
            d, h = self.decoder(self.embedding(token), h)  # d: (1, 1, hidden_dim)
            token = torch.argmax(self.linear(d), dim=-1)    # (1, 1)
            r = token.item()
            if r == stop_id:
                break  # nothing after <STOP> is ever part of the answer
            s += all_letters[r]
        return s
        

In [50]:
def train_on_sample(optimizer, loss_function, model, random_training_sample):
    """Run one teacher-forced optimisation step on a single sample.

    Args:
        optimizer: optimizer over the encoder and decoder parameters.
        loss_function: criterion called as ``loss_function(logits, target)`` with
            logits laid out (batch, vocab, seq), e.g. ``nn.CrossEntropyLoss``.
        model: ``(encoder, decoder)`` pair.
        random_training_sample: ``(encoder_input, decoder_input, decoder_target)``
            tensors (the parameter shadows the global function of the same name).

    Returns:
        ``(logits, loss)``: the decoder logits transposed to (batch, vocab, seq)
        and the loss tensor computed before the optimizer step.
    """
    encoder, decoder = model
    enc_in, dec_in, dec_target = random_training_sample

    # The encoder's final hidden state initialises the decoder.
    context = encoder(enc_in, encoder.init_hidden())

    # Decoder yields (batch, seq, vocab); the criterion wants the class dim second.
    logits = decoder(dec_in, context).transpose(1, 2)
    loss = loss_function(logits, dec_target)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return logits, loss

In [51]:
# Seed every RNG this cell depends on so a Restart & Run All reproduces the run:
# `random` drives random_training_sample, torch drives weight initialisation.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

rnn_encoder = Encoder(n_letters, 256, 256,
                      num_layers = 3)

# The encoder's final hidden state is fed directly to the decoder GRU as its
# initial state, so hidden_dim and num_layers must match the encoder's.
rnn_decoder = Decoder(vocab_dim = n_letters,
                      output_dim = n_letters,
                      emb_dim = 256,
                      hidden_dim = 256,
                      num_layers = 3)

learning_rate = 0.0001

# One optimizer over both networks: they are trained jointly, end to end.
optimizer = torch.optim.Adam(itertools.chain(rnn_encoder.parameters(),
                                             rnn_decoder.parameters()),
                             lr=learning_rate)

loss_function = nn.CrossEntropyLoss()

n_iters = 5000
print_every = 50
total_loss = 0

model = (rnn_encoder, rnn_decoder)

# `step` rather than `iter`, so the builtin iter() is not shadowed in the notebook.
for step in range(1, n_iters + 1):
    _, loss = train_on_sample(optimizer, loss_function, model, random_training_sample())
    total_loss += loss.item()

    if step % print_every == 0:
        # Mean loss over the last `print_every` samples.
        print('(%d %d%%) %.4f' % (step, step / n_iters * 100, total_loss / print_every))
        total_loss = 0


(50 1%) 2.4555
(100 2%) 2.0458
(150 3%) 1.9165


KeyboardInterrupt: 

In [None]:
max_length = 30  # default cap on the number of tokens sample() lets the decoder emit


def sample(word, max_len=None):
    """Encode an expression and greedily decode the model's answer.

    Args:
        word: expression string such as ``'144+543='``. Every character must be a
            key of ``char_to_int``; ``input_tensor`` raises KeyError otherwise.
        max_len: maximum number of decoder steps. ``None`` means use the
            notebook-level ``max_length``, read at call time.

    Returns:
        The decoded answer string produced by ``rnn_decoder.generate``.
    """
    if max_len is None:
        max_len = max_length
    with torch.no_grad():  # inference only: no autograd history needed
        encode_line_tensor = input_tensor(word)
        hidden = rnn_encoder(encode_line_tensor, rnn_encoder.init_hidden())
        return rnn_decoder.generate(hidden, max_len=max_len)


In [None]:
sample('2+2=')  # correct answer: '4'

In [None]:
sample('144+543=')  # correct answer: '687'