In [311]:
import torch
import torch.nn as nn
import random
import glob
import string
import unicodedata
import re


In [523]:
class Lang():
    """Vocabulary for one language: word <-> index maps plus word frequencies.

    Indices 0 and 1 are reserved for the "SOS" and "EOS" markers (see
    ``index2Word``); real words are numbered from 2 upwards in order of
    first appearance.
    """

    def __init__(self, name):
        self.name = name
        self.n_words = 2  # next free index; 0 and 1 are taken by SOS/EOS
        self.word2Index = {}
        self.word2Count = {}
        self.index2Word = {0: "SOS", 1: "EOS"}

    def addSentence(self, sens):
        """Add every space-separated token of ``sens`` to the vocabulary."""
        for s in sens.split(' '):
            self.addWord(s)

    def addWord(self, word):
        """Register ``word`` under the next free index, or bump its count."""
        if word not in self.word2Index:
            index = self.n_words
            self.word2Index[word] = index
            # Bug fix: the reverse map must use the same index as word2Index.
            # It used to be written after n_words was incremented, so every
            # word landed one slot too high and index2Word[2] never existed.
            self.index2Word[index] = word
            self.word2Count[word] = 1
            self.n_words += 1
        else:
            self.word2Count[word] += 1
            

In [524]:
def unicodeToAscii(s):
    """Strip accents from ``s``.

    The text is decomposed with Unicode NFD normalization, which splits an
    accented letter into its base letter plus combining marks. Every
    character in the nonspacing-mark category ("Mn") is then dropped.
    Characters that are not combining marks are kept unchanged, so the
    result is not guaranteed to be pure ASCII despite the name.
    """
    decomposed = unicodedata.normalize('NFD', s)
    kept = [ch for ch in decomposed if not unicodedata.category(ch) == 'Mn']
    return ''.join(kept)

# Lowercase, trim, and remove non-letter characters


def normalizeString(s):
    """Normalize a raw sentence into space-separated lowercase tokens.

    Steps:
    - lowercase and trim the input;
    - strip accents via ``unicodeToAscii``;
    - put a space before each ``.``, ``!`` and ``?`` so punctuation becomes
      its own token;
    - collapse every other run of non-letter characters into one space.

    The result is stripped at the end. When a sentence starts or ends with
    characters that get replaced (e.g. French guillemets or a leading quote),
    the substitution leaves a leading or trailing space. ``split(' ')``
    downstream would then yield empty-string tokens that pollute the
    vocabulary and inflate token counts.
    """
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    return s.strip()

In [525]:
# Read the raw corpus once for inspection. `with` closes the file handle;
# a bare open(...).read() leaves closing it to the garbage collector.
with open('./data/eng-fra.txt', encoding='utf-8') as corpus_file:
    lines = corpus_file.read().strip().split('\n')


In [526]:
# Raw (un-normalized) sentence pairs: each corpus line split on tab characters.
pairs = list(map(lambda raw_line: raw_line.split('\t'), lines))

In [532]:
def readLang(lang1, lang2, reverse=False, path='./data/eng-fra.txt'):
    """Load and normalize sentence pairs and create empty vocabularies.

    Args:
        lang1, lang2: names for the two ``Lang`` objects. They are only used
            as names; the file read is ``path`` regardless of their values.
        reverse: if True, reverse each pair and swap the language roles so
            that ``lang2`` becomes the input language.
        path: tab-separated corpus file, one pair per line. Defaults to the
            previously hard-coded location.

    Returns:
        ``(input_lang, output_lang, pairs)``. Both ``Lang`` objects are still
        empty. ``pairs`` is a list with one list of normalized,
        tab-separated columns per line.
    """
    # `with` guarantees the file handle is closed.
    with open(path, encoding='utf-8') as corpus_file:
        lines = corpus_file.read().strip().split('\n')
    pairs = [[normalizeString(k) for k in line.split('\t')] for line in lines]

    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)
    else:
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)

    return input_lang, output_lang, pairs
    

In [533]:
# input_lang, output_lang, pairs = readLang('eng','fra',True)

In [534]:
def filterPair(p, max_len=None, prefixes=None):
    """Return True if a pair is short enough and ``p[0]`` starts with a kept prefix.

    Args:
        p: pair whose first two elements are space-tokenized sentences.
        max_len: keep only pairs where both sides have strictly fewer than
            this many tokens. Defaults to the notebook-global ``max_length``.
        prefixes: a str or a sequence of str; ``p[0]`` must start with one
            of them. Defaults to the notebook-global ``eng_prefixes``.

    NOTE(review): ``max_length`` and ``eng_prefixes`` are not defined in any
    cell of this notebook section. Calling this without arguments on a fresh
    kernel raises NameError. Define them in a config cell or pass them
    explicitly.
    """
    if max_len is None:
        max_len = max_length
    if prefixes is None:
        prefixes = eng_prefixes
    if not isinstance(prefixes, str):
        # str.startswith accepts only a str or a tuple, not a list.
        prefixes = tuple(prefixes)
    return (len(p[0].split(' ')) < max_len
            and len(p[1].split(' ')) < max_len
            and p[0].startswith(prefixes))
        

In [535]:
def filterPairs(pairs):
    """Return a new list containing only the pairs accepted by ``filterPair``."""
    return list(filter(filterPair, pairs))

In [537]:
def preparedData(lang1, lang2, reverse=False):
    """Read, filter and index the corpus.

    Returns ``(input_lang, output_lang, pairs)``. Both vocabularies have been
    filled from the filtered pairs. Also prints each vocabulary's name and
    word count.
    """
    src_lang, tgt_lang, raw_pairs = readLang(lang1, lang2, reverse)
    kept_pairs = filterPairs(raw_pairs)
    for pair in kept_pairs:
        src_lang.addSentence(pair[0])
        tgt_lang.addSentence(pair[1])
    print("count words")
    for vocab in (src_lang, tgt_lang):
        print(vocab.name, vocab.n_words)
    return src_lang, tgt_lang, kept_pairs

    
    

In [538]:
# Build the English -> French vocabularies and the filtered training pairs.
input_lang, output_lang, pairs = preparedData(lang1="eng", lang2="fra", reverse=False)

count words
eng 2925
fra 4489


In [539]:
# Spot-check one random training pair.
sample_pair = random.choice(pairs)
print(sample_pair)

['she is interested in learning new ideas .', 'elle est interessee a apprendre de nouvelles idees .']


In [540]:
pairs[:5]  # first five filtered pairs

[['i m .', 'j ai ans .'],
 ['i m ok .', 'je vais bien .'],
 ['i m ok .', 'ca va .'],
 ['i m fat .', 'je suis gras .'],
 ['i m fat .', 'je suis gros .']]

In [546]:
class EncoderRNN(nn.Module):
    """Single-layer GRU encoder that consumes one token index per call.

    Args:
        input_size: source vocabulary size (rows of the embedding table).
        hidden_size: embedding width and GRU hidden size.
        output_size: accepted for signature compatibility; not used.
    """

    def __init__(self, input_size, hidden_size, output_size=None):
        super(EncoderRNN, self).__init__()

        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)

    def forward(self, input, hidden):
        """Run one GRU step and return ``(output, hidden)``.

        ``input`` should hold a single token index. Its embedding is
        reshaped to (1, 1, -1), so several indices would be flattened into
        one wider vector that the GRU rejects.
        """
        embedded = self.embedding(input).view(1, 1, -1)
        output, hidden = self.gru(embedded, hidden)
        return output, hidden

    def initHidden(self, device=None):
        """Return a zero initial hidden state of shape (1, 1, hidden_size).

        ``device`` lets the state live on the same device as the input
        tensors (``tensorFromSentence`` places them on ``device``). The
        default keeps the previous behaviour of torch's default device.
        """
        return torch.zeros(1, 1, self.hidden_size, device=device)

In [549]:
class DecoderRNN(nn.Module):
    """Single-layer GRU decoder that emits log-probabilities for one step.

    Args:
        input_size: vocabulary size of the tokens fed into the decoder
            (rows of the embedding table).
        hidden_size: embedding width and GRU hidden size.
        output_size: size of the predicted vocabulary (width of the
            log-softmax output).
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(DecoderRNN, self).__init__()

        # Bug fix: initHidden reads self.hidden_size, which was never set.
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.relu = nn.ReLU()
        # Bug fix: the GRU consumes the ReLU'd embedding, whose width is
        # hidden_size. It was declared with output_size as its input size,
        # which only worked when output_size == hidden_size.
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        """Decode one token and return ``(log_probs, hidden)``.

        ``log_probs`` has shape (1, output_size). ``input`` should hold a
        single token index, because its embedding is reshaped to (1, 1, -1).
        """
        embedded = self.embedding(input).view(1, 1, -1)
        activated = self.relu(embedded)
        output, hidden = self.gru(activated, hidden)
        output = self.out(output[0])
        output = self.softmax(output)
        return output, hidden

    def initHidden(self, device=None):
        """Return a zero initial hidden state of shape (1, 1, hidden_size)."""
        return torch.zeros(1, 1, self.hidden_size, device=device)
        
        
        

In [550]:
class DecoderRNN1(nn.Module):
    """Alternative decoder: embedding -> ReLU -> GRU -> linear -> log-softmax.

    Args:
        hidden_size: embedding width and GRU hidden size.
        output_size: target vocabulary size (embedding rows and output width).
    """

    def __init__(self, hidden_size, output_size):
        # Bug fix: super() was called with DecoderRNN, an unrelated class,
        # which raised TypeError as soon as the module was constructed.
        super(DecoderRNN1, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        """Decode one token and return ``(log_probs, hidden)``.

        ``log_probs`` has shape (1, output_size).
        """
        output = self.embedding(input).view(1, 1, -1)
        # Bug fix: `F` (torch.nn.functional) is not imported in the
        # notebook's import cell. torch.relu is the same operation.
        output = torch.relu(output)
        output, hidden = self.gru(output, hidden)
        output = self.softmax(self.out(output[0]))
        return output, hidden

    def initHidden(self, device=None):
        """Return a zero initial hidden state of shape (1, 1, hidden_size)."""
        return torch.zeros(1, 1, self.hidden_size, device=device)


In [578]:
class attentionRNN(nn.Module):
    """Work-in-progress attention model; only the embedding layer exists so far.

    TODO: no forward() is defined yet, so calling the module raises
    nn.Module's NotImplementedError for a missing forward.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(attentionRNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        # Bug fix: nn.Embedding needs both num_embeddings and embedding_dim;
        # nn.Embedding(input_size) raised TypeError at construction time.
        # The width hidden_size matches the other RNNs in this notebook.
        # NOTE(review): confirm this is the intended embedding size.
        self.embedding = nn.Embedding(input_size, hidden_size)
        

SyntaxError: unexpected EOF while parsing (<ipython-input-578-c47856204351>, line 1)

In [552]:
def indexesFromSentence(lang, sentence):
    """Map each space-separated token of ``sentence`` to its index in ``lang``.

    Raises KeyError for a token that is not in ``lang.word2Index``.
    """
    lookup = lang.word2Index
    return list(map(lookup.__getitem__, sentence.split(' ')))

def tensorFromSentence(lang, sentence, eos_token=None, device=None):
    """Encode ``sentence`` as a column tensor of token indices ending in EOS.

    Returns a ``torch.long`` tensor of shape (n_tokens + 1, 1).

    Args:
        lang: vocabulary providing ``word2Index``.
        sentence: space-tokenized sentence; every token must be known.
        eos_token: index appended at the end. Defaults to the notebook-global
            ``EOS_token`` if one is defined. Otherwise it defaults to 1, the
            index ``Lang.index2Word`` reserves for "EOS".
        device: target device. Defaults to the notebook-global ``device`` if
            one is defined, otherwise torch's default device.

    NOTE(review): ``EOS_token`` and ``device`` are not defined in any cell
    of this notebook section, so the previous version raised NameError on a
    fresh kernel. When the globals do exist, behaviour is unchanged.
    """
    if eos_token is None:
        eos_token = globals().get('EOS_token', 1)
    if device is None:
        device = globals().get('device')
    indexes = indexesFromSentence(lang, sentence)
    indexes.append(eos_token)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)

def tensorsFromPair(pair):
    """Encode a [source, target] pair using the global ``input_lang``/``output_lang``.

    Returns ``(source_tensor, target_tensor)``.
    """
    return (
        tensorFromSentence(input_lang, pair[0]),
        tensorFromSentence(output_lang, pair[1]),
    )


In [577]:
# Shape sanity check for nn.GRU. batch_first defaults to False, so tensors
# are (seq_len, batch, features).
rnn = nn.GRU(input_size=10, hidden_size=20, num_layers=2)
input = torch.randn(5, 1, 10)  # seq_len=5, batch=1, features=10
h0 = torch.randn(2, 1, 20)     # num_layers=2, batch=1, hidden=20
output, hn = rnn(input, h0)

In [574]:
input.shape  # expected torch.Size([5, 1, 10]) per the GRU demo cell; the saved [8, 1, 10] output is stale


torch.Size([8, 1, 10])

In [571]:
hn.shape  # (num_layers, batch, hidden_size)

torch.Size([2, 1, 20])

In [572]:
output.size()

torch.Size([5, 1, 20])