In [1]:
import re
import unicodedata

import torch
import torch.functional as F
import torch.nn as nn
# NOTE: `torch.functional` (imported above) has no relu/softmax etc.; rebind F
# to `torch.nn.functional`, which is what the model code expects.
import torch.nn.functional as F


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class Lang:
    """Vocabulary built incrementally from space-separated sentences.

    Indices 0 and 1 are reserved for the ``'SOS'`` / ``'EOS'`` markers in
    ``index2word``; real words are numbered from 2 upwards in order of first
    appearance. The markers are not entered in ``word2index`` / ``word2count``.
    """

    def __init__(self):
        self.word2index = {}   # word -> index
        self.word2count = {}   # word -> number of occurrences seen so far
        self.index2word = {0: 'SOS', 1: 'EOS'}  # index -> word (incl. markers)
        self.n_words = 2       # vocabulary size, counting SOS and EOS

    def addWord(self, word):
        """Record one occurrence of ``word``, giving it a new index on first sight."""
        if word in self.word2index:
            # Increment (not reset) so word2count reflects true frequency.
            self.word2count[word] += 1
            return
        self.word2index[word] = self.n_words
        self.index2word[self.n_words] = word
        self.word2count[word] = 1
        self.n_words += 1

    def addSentence(self, sentence):
        """Add every token of ``sentence``, splitting on single spaces.

        ``split(' ')`` (not ``split()``) is used, so leading, trailing or
        repeated spaces produce empty-string tokens; callers should pass
        sentences already normalised to single-space separation.
        """
        for word in sentence.split(' '):
            self.addWord(word)

In [4]:
# Turn a Unicode string to plain ASCII, thanks to
# https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    """Strip diacritics from ``s``.

    The string is decomposed with Unicode NFD, then every nonspacing combining
    mark (category ``'Mn'``) is dropped. Only combining marks are removed:
    characters without such a decomposition (e.g. ``'¿'``, ``'ß'``) are kept,
    so the result is not guaranteed to be pure ASCII despite the name.
    """
    decomposed = unicodedata.normalize('NFD', s)
    kept = []
    for ch in decomposed:
        if unicodedata.category(ch) == 'Mn':
            continue  # accent split off from its base letter by NFD
        kept.append(ch)
    return ''.join(kept)

# Lowercase, trim and strip accents; put a space before . ! ? and replace
# every other run of non-letter characters (digits included) with one space.


def normalizeString(s):
    """Normalise a raw sentence for tokenisation with ``split(' ')``.

    The steps are:
    - lowercase, trim and strip accents (via ``unicodeToAscii``);
    - insert a space before each ``.``, ``!`` or ``?`` so it becomes its own
      token;
    - collapse every other run of non-letter characters (digits and other
      punctuation included) into a single space.

    Example: ``"'Hello,   World!'"`` -> ``"hello world !"``.
    """
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    # The substitution above can leave a leading/trailing space (e.g. from
    # surrounding quotes or a leading '¿'), which split(' ') would turn into
    # empty-string tokens.
    return s.strip()

In [6]:
class Encoder(nn.Module):
    """Single-layer GRU encoder over token indices (batch size 1).

    Args:
        input_size: size of the input vocabulary (rows of the embedding table).
        hidden_size: embedding dimension, which is also the GRU hidden size.
    """

    def __init__(self, input_size, hidden_size):
        # nn.Module.__init__ must run before submodules are assigned;
        # super(nn.Module, self) skipped it and made the assignments below fail.
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        # The GRU consumes embeddings, whose width is hidden_size (not input_size).
        self.rnn = nn.GRU(hidden_size, hidden_size)

    def forward(self, input, hidden):
        """Run the GRU over ``input``.

        Args:
            input: LongTensor of token indices, either a single index or a 1-D
                sequence of length L. Batch size 1 is assumed, matching
                ``initHidden``.
            hidden: initial state of shape (1, 1, hidden_size).

        Returns:
            Tuple ``(out, hidden)`` with ``out`` of shape (L, 1, hidden_size)
            and ``hidden`` of shape (1, 1, hidden_size).
        """
        # Reshape to the GRU's default (seq_len, batch=1, features) layout.
        embed = self.embedding(input).view(-1, 1, self.hidden_size)
        out, hidden = self.rnn(embed, hidden)
        return (out, hidden)

    def initHidden(self):
        """Return a zero initial hidden state of shape (1, 1, hidden_size)."""
        return torch.zeros(1, 1, self.hidden_size)

class Decoder(nn.Module):
    """Single-layer GRU decoder producing a distribution over the vocabulary.

    Args:
        input_size: vocabulary size, used both for the embedding table and as
            the width of the output distribution.
        hidden_size: embedding dimension and GRU hidden size.
    """

    def __init__(self, input_size, hidden_size):
        # Must call nn.Module.__init__ (super(nn.Module, self) skipped it).
        super().__init__()
        # Read by forward and initHidden; it was previously never set.
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        # The GRU consumes embeddings of width hidden_size.
        self.rnn = nn.GRU(hidden_size, hidden_size)
        # Project GRU outputs to one score per vocabulary entry.
        self.out = nn.Linear(hidden_size, input_size)
        # Normalise over the last (vocabulary) axis. NOTE(review): if training
        # with nn.NLLLoss, nn.LogSoftmax(dim=-1) is the matching choice.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input, hidden):
        """Decode ``input`` tokens given the previous hidden state.

        Args:
            input: LongTensor of token indices, either a single index or a 1-D
                sequence of length L. Batch size 1 is assumed, matching
                ``initHidden``.
            hidden: previous state of shape (1, 1, hidden_size).

        Returns:
            Tuple ``(probs, hidden)``. ``probs`` has shape (L, 1, input_size)
            and sums to 1 over the last axis; ``hidden`` has shape
            (1, 1, hidden_size).
        """
        # Reshape to the GRU's default (seq_len, batch=1, features) layout.
        embed = self.embedding(input).view(-1, 1, self.hidden_size)
        out = nn.functional.relu(embed)
        out, hidden = self.rnn(out, hidden)
        out = self.softmax(self.out(out))
        return (out, hidden)

    def initHidden(self):
        """Return a zero initial hidden state of shape (1, 1, hidden_size)."""
        return torch.zeros(1, 1, self.hidden_size)