In [1]:
import numpy as np
import keras
import matplotlib.pyplot as plt
from gensim.corpora import Dictionary

%matplotlib inline

Using TensorFlow backend.


In [2]:
from s2sutils import *

In [3]:
from scipy.sparse import csc_matrix, csr_matrix

In [4]:
# load models
dictionary = Dictionary.load('chartovec.dict')

model = keras.models.model_from_json(open('s2s.json', 'rb').read())
model.load_weights('s2s.h5')

encoder_model = keras.models.model_from_json(open('s2s_encoder.json', 'rb').read())
encoder_model.load_weights('s2s_encoder.h5')

decoder_model = keras.models.model_from_json(open('s2s_decoder.json', 'rb').read())
decoder_model.load_weights('s2s_decoder.h5')

In [5]:
chartovec_encoder = SentenceToCharVecEncoder(dictionary)

In [6]:
numchars = len(chartovec_encoder.dictionary)
latent_dim = numchars + 20

print numchars
print latent_dim

93
113


In [20]:
def decode_sequence(input_sent, dictionary, maxlen=20):
    num_chars = len(chartovec_encoder.dictionary)
    
    # Encode the input as state vectors.
    input_seq = np.array([chartovec_encoder.encode_sentence(input_sent, endsig=True, maxlen=maxlen).toarray()])
    for i in range(maxlen):
        print(np.argmax(input_seq[0, i, :]))
    states_value = encoder_model.predict(input_seq)
    print(states_value)

    # Generate empty target sequence of length 1.
    target_seq = np.zeros((1, 1, num_chars))
    # Populate the first character of target sequence with the start character.
    target_seq[0, 0, dictionary.token2id['\n']] = 1.

    # Sampling loop for a batch of sequences
    # (to simplify, here we assume a batch of size 1).
    stop_condition = False
    decoded_sentence = ''
    while not stop_condition:
        output_tokens, h, c = decoder_model.predict([target_seq] + states_value)

        # Sample a token
        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        sampled_char = dictionary[sampled_token_index]
        decoded_sentence += sampled_char

        # Exit condition: either hit max length
        # or find stop character.
        if (sampled_char == '\n' or len(decoded_sentence) > maxlen):
            stop_condition = True

        # Update the target sequence (of length 1).
        target_seq = np.zeros((1, 1, num_chars))
        target_seq[0, 0, sampled_token_index] = 1.

        # Update states
        states_value = [h, c]

    return decoded_sentence

In [22]:
decode_sequence('Happy Holiday!', dictionary)

6
30
39
39
32
1
6
22
19
31
12
30
32
53
0
0
0
0
0
0
[array([[ 0.62237471, -0.10841371, -0.10988645,  0.1053278 ,  0.19763544,
        -0.28970626,  0.75328898,  0.9321388 ,  0.        , -0.        ,
         0.        , -0.        , -0.17736164, -0.03423658,  0.17625512,
        -0.52862954, -0.0668629 , -0.90082604,  0.        ,  0.        ,
         0.05833079,  0.9865998 ,  0.        ,  0.        ,  0.        ,
        -0.16179021,  0.        ,  0.        ,  0.        , -0.        ,
        -0.00646369,  0.28666481,  0.30390584,  0.02576808, -0.28487369,
         0.75106221, -0.14938322, -0.        ,  0.13125537,  0.92430687,
         0.01133212, -0.34009448,  0.38465458, -0.95232707, -0.09826176,
        -0.        ,  0.82051015, -0.25243929, -0.1420102 , -0.45547831,
         0.        ,  0.11400329,  0.32928073,  0.09840935, -0.        ,
         0.54697138, -0.1586163 , -0.        ,  0.        , -0.19799042,
         0.28601494,  0.        ,  0.13606215,  0.        , -0.17887902,

u'"I have the countess\n'

In [23]:
decode_sequence('How are you?', dictionary)

6
22
40
1
30
23
13
1
32
22
26
75
0
0
0
0
0
0
0
0
[array([[  5.32119930e-01,  -5.58114052e-02,  -3.55764516e-02,
          2.54105896e-01,   3.33219796e-01,  -3.76831234e-01,
          7.01018274e-01,   8.57998312e-01,   0.00000000e+00,
         -0.00000000e+00,   0.00000000e+00,  -5.96856757e-04,
         -2.07998902e-01,  -2.25872244e-03,   1.69165254e-01,
         -3.41503799e-01,  -5.95183484e-02,  -7.81211257e-01,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          9.95384753e-01,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00,  -1.11100562e-01,   0.00000000e+00,
          0.00000000e+00,   0.00000000e+00,  -0.00000000e+00,
         -0.00000000e+00,   3.83473694e-01,   4.08766091e-01,
          0.00000000e+00,  -1.22126631e-01,   8.92370343e-01,
         -1.44536048e-01,  -0.00000000e+00,   3.44112247e-01,
          9.62628603e-01,   0.00000000e+00,  -7.32973337e-01,
          5.72976530e-01,  -8.78241539e-01,  -1.33580998e-01,
         -0.00000000

u'"I have the countess\n'