<a href="https://colab.research.google.com/github/rybread1/trump_speech_writer/blob/master/trump_speech_writer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [8]:
import numpy as np
import tensorflow as tf
import os

In [9]:
!curl -O https://raw.githubusercontent.com/ryanmcdermott/trump-speeches/master/speeches.txt

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0  4  903k    4 42501    0     0   532k      0  0:00:01 --:--:--  0:00:01  532k100  903k  100  903k    0     0  7402k      0 --:--:-- --:--:-- --:--:-- 7342k


In [10]:
## Reading and processing text
with open('speeches.txt', 'r') as fp:
    text = fp.read()
    
start_indx = text.find('Thank you so much')

text = text[start_indx:].lower()  # trimmed text doc
char_set = set(text) # unique character set
char_set_sorted = sorted(char_set)

char_2_int_dict = {ch:i for i,ch in enumerate(char_set_sorted)} # dict mapping char to int
char_array = np.array(char_set_sorted) # array mapping idx to char

text_encoded = np.array(
    [char_2_int_dict[ch] for ch in text],
    dtype=np.int32)

In [17]:
ds_text_encoded = tf.data.Dataset.from_tensor_slices(text_encoded)

seq_length = 40 
chunk_size = seq_length + 1
ds_chunks = ds_text_encoded.batch(chunk_size, drop_remainder=True) 

## define the function for splitting x & y
def split_input_target(chunk):
    input_seq = chunk[:-1]
    target_seq = chunk[1:]
    return input_seq, target_seq

ds_sequences = ds_chunks.map(split_input_target)

# Batch size
BATCH_SIZE = 32
BUFFER_SIZE = 200000

tf.random.set_seed(1)
ds = ds_sequences.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

def get_test_train_split(text, chunk_size, batch_size, train_split=0.9):
    return np.floor(len(text) / chunk_size / batch_size) * train_split

train_batches = get_test_train_split(text_encoded, chunk_size, BATCH_SIZE)

ds_train = ds.take(train_batches)
ds_valid = ds.skip(train_batches)

In [18]:
def build_model(input_size, vocab_size, embedding_dim, rnn_units, dropout=True):
    inputs = tf.keras.Input(input_size)
    x = tf.keras.layers.Embedding(vocab_size, embedding_dim)(inputs)
    x = tf.keras.layers.LSTM(rnn_units, return_sequences=True)(x)
    x = tf.keras.layers.LSTM(rnn_units, return_sequences=True)(x)
    if dropout:
        x = tf.keras.layers.Dropout(0.5)(x)
    outputs = tf.keras.layers.Dense(vocab_size)(x)
    model = tf.keras.Model(inputs, outputs)
    return model

tf.random.set_seed(42)

model = build_model(input_size=seq_length, vocab_size=len(char_array), 
                    embedding_dim=256, rnn_units=512)

model.compile(optimizer='adam', 
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))

model.summary()

Model: "functional_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_3 (InputLayer)         [(None, 40)]              0         
_________________________________________________________________
embedding_2 (Embedding)      (None, 40, 256)           16896     
_________________________________________________________________
lstm_2 (LSTM)                (None, 40, 512)           1574912   
_________________________________________________________________
lstm_3 (LSTM)                (None, 40, 512)           2099200   
_________________________________________________________________
dropout (Dropout)            (None, 40, 512)           0         
_________________________________________________________________
dense_3 (Dense)              (None, 40, 66)            33858     
Total params: 3,724,866
Trainable params: 3,724,866
Non-trainable params: 0
____________________________________________

In [None]:
checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=0)

results = model.fit(ds_train, validation_data=ds_valid, epochs=20, callbacks=[cp_callback])

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20

In [12]:
def generate_text(model, starting_str, 
           len_generated_text=500, 
           max_input_length=80,
           scale_factor=1.0):
    
    starting_str = starting_str.lower()
    encoded_input = [char_2_int_dict[s] for s in starting_str]
    encoded_input = tf.reshape(encoded_input, (1, -1))

    generated_str = starting_str

    model.reset_states()
    for i in range(len_generated_text):
        logits = model(encoded_input)        
        logits = tf.squeeze(logits, 0)

        scaled_logits = logits * scale_factor
        new_char_indx = tf.random.categorical(scaled_logits, num_samples=1)
        new_char_indx = tf.squeeze(new_char_indx)[-1].numpy()          
        generated_str += str(char_array[new_char_indx])
        
        new_char_indx = tf.expand_dims([new_char_indx], 0)

        encoded_input = tf.concat(
            [encoded_input, new_char_indx],
            axis=1)
        encoded_input = encoded_input[:, -max_input_length:]

    return generated_str

generated_text = generate_text(model, 
                               starting_str='we are going to make america great again!', 
                               scale_factor=3, 
                               len_generated_text=300)

In [13]:
print(generated_text)

we are going to make america great again! we have a very sergeant bergdahl, read the air conditioners is not a politician fighting and i said that in a landslide. and i said, "what about the hell out of the other candidates who i did a great job in free trade. i think it’s going to be a lot of money to be a members makers in the world and 
