In [1]:
from __future__ import print_function
import tensorflow as tf
from preeminence_utils import tf_utils
import numpy as np
import random
import os

In [2]:
def get_latest_epoch(model_dir="./model/model_southpark/", default=0):
    """Return the highest checkpoint step number found in `model_dir`.

    Checkpoint metadata files are expected to be named ``<prefix>-<step>.meta``
    (e.g. ``model.ckpt-310.meta``). The step is taken from the text after the
    *last* dash, so prefixes that themselves contain dashes parse correctly.
    Files that do not end in ``.meta``, or whose step part is not an integer,
    are ignored.

    Args:
        model_dir: directory holding the checkpoint files.
        default: value returned when no checkpoint metadata file is found
            (e.g. before the first save).

    Returns:
        The largest step number found, or `default` if there is none.
    """
    meta_suffix = ".meta"
    steps = []
    for filename in os.listdir(model_dir):
        if not filename.endswith(meta_suffix):
            continue
        stem = filename[:-len(meta_suffix)]
        step_text = stem[stem.rfind("-") + 1:]
        if step_text.isdigit():
            steps.append(int(step_text))
    return max(steps) if steps else default

In [3]:
# Load the corpus and build character-level (input window -> next character) pairs.
data_file = "./data/south_park/all_scripts.txt"
# `with` closes the file handle as soon as the corpus has been read.
# NOTE: vocab_length fixes the shape of W (and therefore of the saved weights),
# so any change to how `text` is decoded or cleaned would break restoring
# existing checkpoints.
with open(data_file) as corpus_file:
    text = corpus_file.read().strip()
vocab = sorted(list(set(text)))
vocab_length = len(vocab)
characters2id = dict((c, i) for i, c in enumerate(vocab))
id2characters = dict((i, c) for i, c in enumerate(vocab))
section_length = 50  # characters of context per training window
step = 10            # stride between the start of consecutive windows
sections = []
section_labels = []
# The label of each window is the character right after it, so the last
# usable start index is len(text) - section_length - 1.
for i in range(0, len(text) - section_length, step):
    sections.append(text[i:i + section_length])
    section_labels.append(text[i + section_length])

# float32 matches the tf.float32 placeholders and halves memory versus the
# float64 default (X_data alone is ~1.45e9 elements for this corpus).
X_data = np.zeros((len(sections), section_length, vocab_length), dtype=np.float32)
Y_data = np.zeros((len(sections), vocab_length), dtype=np.float32)
for i, section in enumerate(sections):
    for j, letter in enumerate(section):
        X_data[i, j, characters2id[letter]] = 1
    Y_data[i, characters2id[section_labels[i]]] = 1

print(X_data.shape, Y_data.shape)


(260539, 50, 111) (260539, 111)


In [4]:
# Create the project's model wrapper (tf_utils is project-local; its behaviour
# is not visible in this notebook).
model = tf_utils.Model()
# NOTE(review): if model.init() returns a tf.Graph, Graph.as_default() only
# returns a context manager. Without `with` / __enter__() that graph is NOT made
# the default, so the ops built in later cells land in whatever graph is
# currently default. Confirm against tf_utils.Model.init.
model_graph = model.init().as_default()

In [5]:
# Optimisation settings: SGD step size and mini-batch size.
learning_rate, batch_size = 0.01, 128
# Schedule: epoch budget plus the logging and checkpoint intervals.
total_epochs, log_every, save_every = 500, 100, 10
# Width of the LSTM hidden state; also the row count of the output projection W.
hidden_nodes = 1024

In [6]:
# Graph inputs: batches of one-hot windows and their one-hot next character.
X = tf.placeholder(dtype=tf.float32,
                   shape=[None, section_length, vocab_length],
                   name="X_train")
Y = tf.placeholder(dtype=tf.float32,
                   shape=[None, vocab_length],
                   name="Y_train")

# Output projection from the LSTM state to vocabulary logits. The explicit
# variable names are presumably what checkpoints are keyed on, so keep them.
W = tf.Variable(tf.random_normal(shape=[hidden_nodes, vocab_length]),
                name="Output_weights")
b = tf.Variable(tf.random_normal(shape=[vocab_length]),
                name="Output_bias")

In [7]:
def lstm(x, weights, bias, name_scope="lstm"):
    """Run a single-layer LSTM over `x` and project its last output to logits.

    Args:
        x: input tensor, unstacked along axis 1 into `section_length` time
            steps (the X placeholder is [None, section_length, vocab_length]).
        weights: output projection matrix applied to the last LSTM output
            (W is [hidden_nodes, vocab_length]).
        bias: output bias added after the projection (b is [vocab_length]).
        name_scope: name scope wrapping the created ops.

    Returns:
        Unnormalised logits for the character that follows the sequence.
    """
    with tf.name_scope(name_scope):
        # static_rnn takes a Python list of per-time-step tensors.
        steps = tf.unstack(x, section_length, 1)
        lstm_cell = tf.contrib.rnn.BasicLSTMCell(hidden_nodes, forget_bias=1.0)
        # NOTE: keep the cell / static_rnn construction unchanged - the cell's
        # variable names must keep matching existing checkpoints.
        outputs, _ = tf.contrib.rnn.static_rnn(lstm_cell, steps, dtype=tf.float32)
        # Use the arguments rather than the globals W / b, so the function
        # honours whatever projection the caller passes in.
        return tf.matmul(outputs[-1], weights) + bias

In [8]:
# Model output: logits for the next character, and their softmax probabilities.
logits = lstm(X, W, b)
prediction = tf.nn.softmax(logits)

# Objective: mean softmax cross-entropy against the one-hot labels, minimised
# with plain gradient descent.
loss_op = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=Y, logits=logits))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
train_op = optimizer.minimize(loss_op)

# Accuracy: fraction of rows whose most probable character equals the label.
correct_pred = tf.equal(tf.argmax(prediction, axis=1),
                        tf.argmax(Y, axis=1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, dtype=tf.float32))

# Initialiser for every variable defined above; it must be created after them.
init = tf.global_variables_initializer()
# saver = tf.train.Saver(max_to_keep=0)
# saver = tf.train.Saver(max_to_keep=0)


In [None]:
# Training: restore the saved weights, train for one epoch, save them back.
# NOTE(review): epoch_start is the largest step among the *.meta files, which is
# not necessarily the checkpoint restore_weights loads (a later run reported
# "Testing epoch: 310" while restoring model_weights.ckpt-1) - confirm
# against tf_utils.
epoch_start = get_latest_epoch()
print("Resuming training from epoch: {}".format(epoch_start))
with model.session() as sess:
    # NOTE(review): this assumes a checkpoint already exists; on a fresh run the
    # variables presumably need `sess.run(init)` instead of a restore.
    model.restore_weights("./model/model_southpark/")
    ops = model.train([train_op, loss_op], X, Y, X_data, Y_data,
                      num_epochs=1, batch_size=batch_size)
    model.save_weights(checkpoint_path="./model/model_southpark/")

In [10]:
def sample2(preds, temperature=1.0):
    """Draw one index from a probability vector re-weighted by `temperature`.

    Each probability is re-weighted as p_i ** (1 / temperature) and the result
    is renormalised. A temperature below 1 sharpens the distribution and one
    above 1 flattens it. A single sample is then drawn from it.

    Args:
        preds: 1-D array-like of probabilities (e.g. a softmax output).
        temperature: sampling temperature. 0 deterministically selects the
            most likely index. Must not be negative.

    Returns:
        Integer array of shape (1, len(preds)), one-hot at the chosen index
        (the format of np.random.multinomial(1, p, 1)).

    Raises:
        ValueError: if `temperature` is negative.
    """
    preds = np.asarray(preds).astype('float64')
    if temperature < 0:
        raise ValueError(
            "temperature must be non-negative, got {}".format(temperature))
    if temperature == 0:
        # Greedy decoding: dividing by a zero temperature is undefined.
        greedy = np.zeros((1, preds.size), dtype=int)
        greedy[0, np.argmax(preds)] = 1
        return greedy
    # log(0) = -inf is fine here (exp maps it back to 0); silence the warning.
    with np.errstate(divide='ignore'):
        log_preds = np.log(preds) / temperature
    # Shift by the max before exponentiating. Otherwise a small temperature
    # underflows every entry to 0 and the normalisation becomes 0/0 = NaN.
    exp_preds = np.exp(log_preds - np.max(log_preds))
    preds = exp_preds / np.sum(exp_preds)
    return np.random.multinomial(1, preds, 1)

In [11]:
# Testing: generate `prediction_length` characters from a random seed window of
# the corpus, once per sampling temperature.
prediction_length = 500
temperatures = [0.2]
epoch_test = get_latest_epoch()
print("Testing epoch: {}".format(epoch_test))


def _encode_section(section_text):
    """One-hot encode `section_text` into a (1, section_length, vocab_length) batch.

    Characters are mapped through `characters2id`. Time steps beyond
    len(section_text) are left all-zero.
    """
    encoded = np.zeros((1, section_length, vocab_length))
    for pos, ch in enumerate(section_text):
        encoded[0, pos, characters2id[ch]] = 1
    return encoded


with model.session() as sess:
    model.restore_weights("./model/model_southpark/")
    start_index = random.randint(0, len(text) - section_length - 1)
    seed_text = text[start_index: start_index + section_length]

    for temperature in temperatures:
        print("Temperature = {}".format(temperature))
        # Restart from the same seed for every temperature so the generated
        # samples are comparable (previously state carried over between runs).
        test_start = seed_text
        X_test = _encode_section(test_start)

        print(test_start, end="")
        for _ in range(prediction_length):
            pred = sess.run(prediction, feed_dict={X: X_test}).reshape(-1)
            pred = sample2(pred, temperature)
            next_char = id2characters[np.argmax(pred)]
            print(next_char, end="")

            # Slide the context window one character forward.
            test_start = (test_start + next_char)[-section_length:]
            X_test = _encode_section(test_start)


Testing epoch: 310
INFO:tensorflow:Restoring parameters from ./model/model_southpark/model_weights.ckpt-1
Temperature = 0.2
s hurt Wall*Mart!

CARTMAN
That's not what I said!

ARARSTAR
Ah we'll we bat the dodd gonn!

SSSSTMAN
I'm get the wwar The popent wirk ow
Mabbere�

[Oh. CARKEL
AR KEN
He you't a toud hoob: the frien dee're
stant to arrating the wirn of and watchen
the fire and ard parted and the athent
here the+++lly, and the fare and and K
Phopp *00 Z00000..."

The brivif it. The day To Bug thet Comk arommalick
putter frinting to the reare bexthre
paips, and then me an fare and and here and Zurk. He phoolis
pare and rench]

KYLE
Hey, the kir. Werry haie, Kyle,