In [114]:
from keras.layers import Input, Embedding, SimpleRNN, Dense, merge, Flatten, BatchNormalization, LSTM, TimeDistributed, Dropout
from keras.models import Model
from keras.optimizers import Adam
import urllib2
import numpy as np

dataset_raw = urllib2.urlopen("https://s3.amazonaws.com/text-datasets/nietzsche.txt").read().\
    replace('\n', ' ')

In [115]:
vocab = sorted(list(set([i for i in dataset_raw])))

In [116]:
vocab.insert(0, '\0')

In [117]:
txt_encoder = {v:k for k,v in enumerate(vocab)}
txt_decoder = {k:v for k,v in enumerate(vocab)}

In [118]:
dataset_encoded = [txt_encoder[i] for i in dataset_raw]

In [119]:
vocab_size = len(vocab)

In [120]:
seq_len = 40

In [121]:
train_data = [np.stack([dataset_encoded[i + j] for i in range(0, len(dataset_raw) - seq_len - 1, seq_len)]) for j in range(seq_len)]

In [122]:
output_data = [np.stack([dataset_encoded[i + j]
                         for i in range(0, len(dataset_raw) - seq_len - 1, seq_len)])[:,np.newaxis] 
                         for j in range(1, seq_len + 1)]

In [10]:
inps = []
embs = []

for i in range(seq_len):
    inps.append(Input(shape=(1,), name='inp_%s' % i))
    embs.append(Flatten()(Embedding(input_dim=vocab_size, output_dim=40, name='emb_%s' % i)(inps[i])) )

In [11]:
hidden_layer_size = 256

In [12]:
dense_in = Dense(hidden_layer_size, activation='relu')
dense_hidden = Dense(hidden_layer_size, activation='relu', init='identity')
dense_out = Dense(vocab_size, activation='softmax')

In [13]:
outs = []

zero_inp = Input(shape=(40,), name='zeros')
hidden = dense_in(zero_inp)

for i in range(seq_len):
    bn = BatchNormalization()(embs[i])
    din = dense_in(bn)
    hidden = merge([din, dense_hidden(hidden)])
    outs.append(dense_out(hidden))

In [14]:
zeros = np.tile(np.zeros(40), (len(train_data[0]), 1))

In [None]:
inp = Input(shape=(len(train_data[0]),40))
emb = Embedding(vocab_size, 60, input_length=40)

In [15]:
mdl = Model(input=[zero_inp] + [i for i in inps], output=outs)

In [16]:
mdl.compile(optimizer=Adam(lr=0.00001), loss='sparse_categorical_crossentropy')

In [21]:
mdl.fit([zeros] + train_data, output_data, nb_epoch=12, batch_size=64)

Epoch 1/12
Epoch 2/12
Epoch 3/12
Epoch 4/12
Epoch 5/12
Epoch 6/12
Epoch 7/12
Epoch 8/12
Epoch 9/12
Epoch 10/12
Epoch 11/12
Epoch 12/12


<keras.callbacks.History at 0x7f8672b96590>

In [22]:
mdl.compile(optimizer=Adam(lr=0.0001), loss='sparse_categorical_crossentropy')
mdl.fit([zeros] + train_data, output_data, nb_epoch=12, batch_size=64)

Epoch 1/12
Epoch 2/12
Epoch 3/12
Epoch 4/12
Epoch 5/12
Epoch 6/12
Epoch 7/12
Epoch 8/12
Epoch 9/12
Epoch 10/12
Epoch 11/12
Epoch 12/12


<keras.callbacks.History at 0x7f865e509490>

In [45]:
mdl.compile(optimizer=Adam(lr=0.01), loss='sparse_categorical_crossentropy')
mdl.fit([zeros] + train_data, output_data, nb_epoch=12, batch_size=64)

Epoch 1/12
Epoch 2/12
Epoch 3/12
Epoch 4/12
Epoch 5/12
Epoch 6/12
Epoch 7/12
Epoch 8/12
Epoch 9/12
Epoch 10/12
Epoch 11/12
Epoch 12/12


<keras.callbacks.History at 0x7f865a3911d0>

In [56]:
mdl.compile(optimizer=Adam(lr=0.001), loss='sparse_categorical_crossentropy')
mdl.fit([zeros] + train_data, output_data, nb_epoch=12, batch_size=64)

Epoch 1/12
Epoch 2/12
Epoch 3/12
Epoch 4/12
Epoch 5/12
Epoch 6/12
Epoch 7/12
Epoch 8/12
Epoch 9/12
Epoch 10/12
Epoch 11/12
Epoch 12/12


<keras.callbacks.History at 0x7f865bc48e50>

In [17]:
def mdl_predict(seq_3char):
    if len(seq_3char) < 8:
        padding_len = 8 - len(seq_3char)
        padding = '\0' * padding_len
        seq_3char = padding + seq_3char
    pred_data = [txt_encoder[i] for i in seq_3char]
    arrs = [np.stack([i]) for i in pred_data]
    pred = mdl.predict([np.tile(np.zeros(40), (1, 1))] + arrs)
    return [txt_decoder[np.argmax(o)] for o in pred]

In [18]:
mdl_predict('sufferin')

[')', ';', ';', ';', ';', ';', ';', ';']

In [111]:
def generate_text(num_chars):
    outs = []
    base_str = 'Sufferin'
    for i in range(num_chars):
        prediction = mdl.predict([np.stack(np.zeros(40))[np.newaxis]] +
                                          [np.array([txt_encoder[i]]) for i in base_str])
        next_char = np.argmax(prediction[-1]) # the final model output
        outs.append(txt_decoder[next_char])
        base_str = (base_str + txt_decoder[next_char])[-40:]
    return base_str

In [75]:
generate_text(100)

';S;;S;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;'

In [135]:
inp = Input(batch_shape=(64, seq_len))
emb = Embedding(input_dim=vocab_size, output_dim=40, batch_input_shape=(64,seq_len))(inp)
bn = BatchNormalization()(emb)
rnn = LSTM(output_dim=256, activation='relu', return_sequences=True, stateful=True)(bn)
bn2 = BatchNormalization()(rnn)
d = Dropout(0.2)(bn2)
out = TimeDistributed(Dense(vocab_size, activation='softmax'))(d)

In [136]:
mdl2 = Model(input=inp, output=out)

In [137]:
mdl2.compile(optimizer=Adam(lr=1e-5), loss='sparse_categorical_crossentropy')

In [138]:
x_stateful = np.stack(np.squeeze(train_data), axis=1)[:12800]
y_stateful = np.atleast_3d(np.stack(output_data, axis=1))[:12800]

x_stateful.shape, y_stateful.shape

((12800, 40), (12800, 40, 1))

In [139]:
n_epoch = 10
for i in range(n_epoch):
    mdl2.reset_states()
    mdl2.fit(x_stateful, y_stateful, nb_epoch=1, batch_size=64, shuffle=False)

Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1


In [112]:
mdl2.compile(optimizer=Adam(lr=1e-4), loss='sparse_categorical_crossentropy')
mdl2.fit(x_stateful, y_stateful, nb_epoch=1, batch_size=64, shuffle=False)

Epoch 1/1


<keras.callbacks.History at 0x7fe7b3936550>

In [140]:
def generate_text(num_chars):
    outs = []
    base_str = 'Sufferin'
    for i in range(num_chars):
        prediction = mdl2.predict(np.array([txt_encoder[i] for i in base_str])[np.newaxis])
        next_char = np.argmax(prediction[-1]) # the final model output
        outs.append(txt_decoder[next_char])
        base_str = (base_str + txt_decoder[next_char])[-8:]
    return ''.join(outs)

In [141]:
generate_text(50)

ValueError: Error when checking : expected input_16 to have shape (64, 40) but got array with shape (1, 8)