In [1]:
from __future__ import print_function

from keras.models import Model
from keras.layers import Input, LSTM, Dense
import numpy as np

Using TensorFlow backend.


In [0]:
# Training hyper-parameters.
batch_size, epochs = 64, 100
# Corpus: how many sentence pairs to read, and from which file.
samples, path = 10000, 'fra.txt'
# Number of units in each LSTM (also the size of the h/c state vectors).
latent_dim = 256


In [0]:
# Read up to `samples` tab-separated sentence pairs from `path` and collect
# the character vocabulary of each side.
inputtexts = []
targettexts = []
inputchars = set()
targetchars = set()
with open(path, 'r', encoding='utf-8') as f:
    lines = f.read().split('\n')
for line in lines:
    if len(inputtexts) >= samples:
        break
    if not line.strip():
        # Skip blank lines (e.g. the empty string after the final newline)
        # instead of relying on `len(lines) - 1`, which drops the last real
        # pair when the file has no trailing newline.
        continue
    # Column 0 is the source sentence, column 1 the target. Any further
    # columns are ignored, so 2- and 3-column versions of the file both parse.
    columns = line.split('\t')
    if len(columns) < 2:
        raise ValueError(
            'expected at least 2 tab-separated columns, got: %r' % line)
    inputtext, targettext = columns[0], columns[1]
    # '\t' is used as the decoder's start token and '\n' as its end token.
    targettext = '\t' + targettext + '\n'
    inputtexts.append(inputtext)
    targettexts.append(targettext)
    inputchars.update(inputtext)
    targetchars.update(targettext)

In [0]:
# Freeze each vocabulary into a sorted list (this fixes the one-hot column
# order) and record the longest sentence on each side.
inputchars = sorted(inputchars)
targetchars = sorted(targetchars)
encodertokens, decodertokens = len(inputchars), len(targetchars)
encoderseqlen = max(map(len, inputtexts))
decoderseqlen = max(map(len, targettexts))

In [5]:
# Summarise the dataset dimensions derived above.
for label, value in (
        ('No of samples:', len(inputtexts)),
        ('No of unique input tokens:', encodertokens),
        ('No of unique output tokens:', decodertokens),
        ('Max seq len for inputs:', encoderseqlen),
        ('Max seq len for outputs:', decoderseqlen),
):
    print(label, value)


No of samples: 10000
No of unique input tokens: 71
No of unique output tokens: 93
Max seq len for inputs: 16
Max seq len for outputs: 59


In [0]:
# Character -> one-hot column index, following the order of the vocab lists.
inputtokenindex = {char: i for i, char in enumerate(inputchars)}
targettokenindex = {char: i for i, char in enumerate(targetchars)}

In [0]:
# One-hot tensors of shape (pairs, time steps, vocabulary size), zero-filled.
encoderinputdata = np.zeros(
    (len(inputtexts), encoderseqlen, encodertokens), dtype=np.float32)
decoderinputdata = np.zeros(
    (len(inputtexts), decoderseqlen, decodertokens), dtype=np.float32)
# The decoder target has exactly the decoder input's shape and dtype.
decodertargetdata = np.zeros_like(decoderinputdata)

In [0]:
# One-hot encode every sentence pair. Time steps past the end of a sentence
# are padded with the space character when the vocabulary contains one;
# otherwise they are left all-zero instead of raising KeyError.
for i, (inputtext, targettext) in enumerate(zip(inputtexts, targettexts)):
    for t, char in enumerate(inputtext):
        encoderinputdata[i, t, inputtokenindex[char]] = 1.
    # Pad from len(inputtext), not `t + 1`: for an empty sentence `t` would
    # be unbound (first row) or left over from the previous row's loop.
    if ' ' in inputtokenindex:
        encoderinputdata[i, len(inputtext):, inputtokenindex[' ']] = 1.
    for t, char in enumerate(targettext):
        decoderinputdata[i, t, targettokenindex[char]] = 1.
        if t > 0:
            # Teacher forcing: the target is the decoder input shifted one
            # step to the left, so the leading '\t' has no target entry.
            decodertargetdata[i, t - 1, targettokenindex[char]] = 1.
    if ' ' in targettokenindex:
        decoderinputdata[i, len(targettext):, targettokenindex[' ']] = 1.
        # The target is one step shorter, so its padding starts one earlier.
        decodertargetdata[i, len(targettext) - 1:, targettokenindex[' ']] = 1.

In [0]:
# Encoder: consume the one-hot source sequence (`None` = variable number of
# time steps) and keep the LSTM's returned states to seed the decoder.
encoderinputs = Input(shape=(None, encodertokens))
encoder = LSTM(units=latent_dim, return_state=True)
# With return_state=True the call yields [output, state_h, state_c].
encoder_outputs, *encoderstates = encoder(encoderinputs)
state_h, state_c = encoderstates

In [0]:
# Training-time decoder: reads the teacher-forced target sequence, starts
# from the encoder's states, and emits a softmax over the target vocabulary
# at every time step.
decoderinputs = Input(shape=(None, decodertokens))
decoderlstm = LSTM(latent_dim, return_sequences=True, return_state=True)
# The call yields [sequence_output, state_h, state_c]; the states are not
# needed here. Index instead of unpacking into `_, _`: binding `_` in the
# IPython namespace shadows the output-history variable (IPython then stops
# updating `_`/`__`/`___`).
decoderoutputs = decoderlstm(decoderinputs, initial_state=encoderstates)[0]
decoderdense = Dense(decodertokens, activation='softmax')
decoderoutputs = decoderdense(decoderoutputs)

In [19]:

# Training model: [source one-hot, teacher-forced target one-hot] ->
# per-step distribution over the next target character.
model = Model(inputs=[encoderinputs, decoderinputs], outputs=decoderoutputs)
model.compile(
    optimizer='rmsprop',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
model.fit(
    x=[encoderinputdata, decoderinputdata],
    y=decodertargetdata,
    batch_size=batch_size,
    epochs=epochs,
    # 20% of the samples are held out (the last ones, per the Keras docs).
    validation_split=0.2,
)
model.save('s2s.h5')

Train on 8000 samples, validate on 2000 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100

In [0]:
# Inference encoder: source one-hot sequence -> [state_h, state_c].
encodermodel = Model(inputs=encoderinputs, outputs=encoderstates)

In [0]:
# Inference decoder: reuses the trained `decoderlstm` and `decoderdense`
# layers, but takes its initial (h, c) as explicit inputs so it can be run
# one step at a time with the previous step's states fed back in.
decoderstateinputh = Input(shape=(latent_dim,))
decoderstateinputc = Input(shape=(latent_dim,))
decoderstatesinputs = [decoderstateinputh, decoderstateinputc]
# Fresh names rather than overwriting `decoderoutputs`/`state_h`/`state_c`.
# Otherwise re-running the training cell after this one would build `model`
# from a tensor that depends on the extra state Inputs above.
inferlstmoutputs, inferstate_h, inferstate_c = decoderlstm(
    decoderinputs, initial_state=decoderstatesinputs)
decoderstates = [inferstate_h, inferstate_c]
inferprobs = decoderdense(inferlstmoutputs)
decodermodel = Model(
    [decoderinputs] + decoderstatesinputs,
    [inferprobs] + decoderstates)


In [0]:

# Column index -> character, inverting the forward lookups.
reverseinputcharindex = {i: char for char, i in inputtokenindex.items()}
reversetargetcharindex = {i: char for char, i in targettokenindex.items()}


In [0]:
def decodesequence(input_seq):
    """Greedily decode one encoded source sequence into a target string.

    input_seq is passed straight to ``encodermodel.predict``. The caller
    slices ``encoderinputdata[i:i + 1]``, so it presumably has a leading
    batch axis of size 1.

    Decoding starts from the '\\t' token and, at each step, feeds back the
    argmax character and the decoder's (h, c) states. It stops after emitting
    '\\n' (which is included in the result) or once the decoded string is
    longer than ``decoderseqlen``.
    """
    def one_hot_step(token_index):
        # A single-time-step, batch-of-one decoder input.
        step = np.zeros((1, 1, decodertokens))
        step[0, 0, token_index] = 1.
        return step

    states = encodermodel.predict(input_seq)
    current = one_hot_step(targettokenindex['\t'])
    decoded = ''
    while True:
        probs, h, c = decodermodel.predict([current] + states)
        next_index = np.argmax(probs[0, -1, :])
        next_char = reversetargetcharindex[next_index]
        decoded += next_char
        if next_char == '\n' or len(decoded) > decoderseqlen:
            return decoded
        current = one_hot_step(next_index)
        states = [h, c]


In [29]:
# Decode the first 100 loaded sentence pairs (or all of them, if fewer were
# loaded) as a quick qualitative check of the trained model. Without the
# bound, a small `samples` setting would index past the end of inputtexts.
for seq_index in range(min(100, len(inputtexts))):
    # Slice rather than index so the leading batch axis (size 1) is kept.
    input_seq = encoderinputdata[seq_index: seq_index + 1]
    decodedsentence = decodesequence(input_seq)
    print('-')
    print('Input sentence:', inputtexts[seq_index])
    print('Decoded sentence:', decodedsentence)

-
Input sentence: Go.
Decoded sentence: Va !

-
Input sentence: Hi.
Decoded sentence: Salut !

-
Input sentence: Hi.
Decoded sentence: Salut !

-
Input sentence: Run!
Decoded sentence: Cours !

-
Input sentence: Run!
Decoded sentence: Cours !

-
Input sentence: Who?
Decoded sentence: Qui ?

-
Input sentence: Wow!
Decoded sentence: Ça alors !

-
Input sentence: Fire!
Decoded sentence: Au feu !

-
Input sentence: Help!
Decoded sentence: À l'aide !

-
Input sentence: Jump.
Decoded sentence: Saute.

-
Input sentence: Stop!
Decoded sentence: Ça suffit !

-
Input sentence: Stop!
Decoded sentence: Ça suffit !

-
Input sentence: Stop!
Decoded sentence: Ça suffit !

-
Input sentence: Wait!
Decoded sentence: Attends !

-
Input sentence: Wait!
Decoded sentence: Attends !

-
Input sentence: Go on.
Decoded sentence: Poursuivez.

-
Input sentence: Go on.
Decoded sentence: Poursuivez.

-
Input sentence: Go on.
Decoded sentence: Poursuivez.

-
Input sentence: Hello!
Decoded sentence: Bonjour !

-
Inpu