In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import re
import unicodedata
import random

In [3]:
SOS_TOKEN = 0
EOS_TOKEN = 1

class Lang():
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {SOS_TOKEN:'SOS', EOS_TOKEN: 'EOS'}
        self.n_words = 2
    
    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.index2word[self.n_words] = word
            self.word2count[word] = 1
            self.n_words += 1
            
        else:
            self.word2count[word] = 1

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

In [4]:
# Turn a Unicode string to plain ASCII, thanks to
# https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

# Lowercase, trim, and remove non-letter characters


def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    return s

In [5]:
with open('data/eng-fra.txt') as f:
    lines = f.readlines()
    print(lines[:5])
    print(re.match('\n',lines[0]))


['Go.\tVa !\n', 'Run!\tCours\u202f!\n', 'Run!\tCourez\u202f!\n', 'Wow!\tÇa alors\u202f!\n', 'Fire!\tAu feu !\n']
None


In [19]:
class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Encoder, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.rnn = nn.GRU(hidden_size,hidden_size)

    def forward(self, input, hidden):
        embed = self.embedding(input).view(1,1,-1)
        # let's log it first to see the shape
        # need to unroll it with 
        out, hidden = self.rnn(embed, hidden)

        return (out, hidden)
    
    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size)

class Decoder(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.rnn = nn.GRU(hidden_size,hidden_size)

        self.softmax = nn.Softmax()

    def forward(self, input, hidden):
        embed = self.embedding(input).view(1,1,-1)
        # let's log it first to see the shape
        # need to unroll it with view()

        out = F.relu(embed)
        out, hidden = self.rnn(out, hidden)

        out = self.softmax(out)

        return (out, hidden)
    
    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size)

In [8]:
def readLangs(lang1, lang2, reverse=False):
    print("Reading lines...")

    # Read the file and split into lines
    lines = open('data/%s-%s.txt' % (lang1, lang2), encoding='utf-8').\
        read().strip().split('\n')

    # Split every line into pairs and normalize
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]

    # Reverse pairs, make Lang instances
    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)
    else:
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)

    return input_lang, output_lang, pairs

MAX_LENGTH = 10

eng_prefixes = (
    "i am ", "i m ",
    "he is", "he s ",
    "she is", "she s ",
    "you are", "you re ",
    "we are", "we re ",
    "they are", "they re "
)


def filterPair(p):
    return len(p[0].split(' ')) < MAX_LENGTH and \
        len(p[1].split(' ')) < MAX_LENGTH and \
        p[1].startswith(eng_prefixes)


def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

In [9]:
def prepareData(lang1, lang2, reverse=False):
    input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)
    print("Read %s sentence pairs" % len(pairs))
    pairs = filterPairs(pairs)
    print("Trimmed to %s sentence pairs" % len(pairs))
    print("Counting words...")
    for pair in pairs:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])
    print("Counted words:")
    print(input_lang.name, input_lang.n_words)
    print(output_lang.name, output_lang.n_words)
    return input_lang, output_lang, pairs


input_lang, output_lang, pairs = prepareData('eng', 'fra', True)
print(random.choice(pairs))
print(random.choice(pairs))

Reading lines...
Read 135842 sentence pairs
Trimmed to 10599 sentence pairs
Counting words...
Counted words:
fra 4345
eng 2803
['nous ecoutons .', 'we re listening .']
['tu es partiale .', 'you re biased .']


In [10]:
# ideally we hot encode the dataset, so we don't do it at the time of the training loop
dataset = []

def sentenceToIndexes(lang, sentence):
    return [lang.word2index[word] for word in sentence.split(' ')]

def sentenceToTensor(lang, sentence):
    ind = sentenceToIndexes(lang, sentence)
    ind.append(EOS_TOKEN)

    return torch.tensor(ind,dtype=torch.long).view(-1, 1)

for pair in pairs:
    dataset.append((
        sentenceToTensor(input_lang,pair[0]),
        sentenceToTensor(output_lang,pair[1]),
        pair[0],
        pair[1]
    
    ))



print(random.choice(dataset))
print(random.choice(dataset))
    

(tensor([[ 123],
        [2426],
        [ 298],
        [2427],
        [ 101],
        [2428],
        [   5],
        [   1]]), tensor([[  77],
        [  78],
        [ 501],
        [  42],
        [1412],
        [   4],
        [   1]]), 'nous subissons une rafale de vent .', 'we re having a blast .')
(tensor([[ 348],
        [ 297],
        [ 349],
        [ 246],
        [1330],
        [   5],
        [   1]]), tensor([[221],
        [ 78],
        [147],
        [218],
        [  4],
        [  1]]), 'ils ne sont pas morts .', 'they re not dead .')


# Training loop

In [23]:
criterion = nn.NLLLoss() 
hidden_size = 128 
learning_rate = 0.01 
encoder = Encoder(input_lang.n_words, hidden_size) 
decoder = Decoder(output_lang.n_words, hidden_size) 
encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate) 
decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate) 

def trainSentence(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer): 
    encoder_optimizer.zero_grad() 
    decoder_optimizer.zero_grad() 
    
    hidden = encoder.initHidden() 

    for word in input_tensor: 
        # print('enc word: ', word) 
        # print('enc hidden: ', hidden) 
        output, hidden = encoder(word, hidden) 
    
    output_tensors = [] # Teacher forcing by default, let's do couple of runs and then do non teacher forced 
    prev_word =torch.tensor([SOS_TOKEN]) 
    for word in target_tensor: 
        output, hidden = decoder(prev_word, hidden) 
        output_tensors.append(output) 
        prev_word = word 
        print('output: ', output) 
        print('word: ', word) # for each word compute loss ? or just do it in the end # compare word with output 

encoder_optimizer.step()

trainSentence(dataset[100][0], dataset[100][1], encoder, decoder, encoder_optimizer, decoder_optimizer)

enc word:  tensor([6])
enc hidden:  tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]])
enc word:  tensor([88])
enc hidden:  tensor([[[-2.0542e-01, -2.0541e-01, -1.1679e-01, -3.9380e-01,  9.6817e-02,
          -5.9606e-02, -2.1363e-01, -1.0045e-01, -9.4199e-02,  8.5993e-02,
          -4.8529e-01, -1.2816e-03, -3.7617e-01,  1.5554e-01, -1.7137e-01,
           1.2149e-01, -6.5451e-02,  6.3395e-02, -7.0855e-02,  5.4596e-02,
          -3.1501e-01,  9.0228e-02,  1.0014e-01, -7.

  out = self.softmax(out)
