In [0]:
import tensorflow as tf
import numpy as np
import os
import time
import matplotlib.pyplot as plt
tf.enable_eager_execution()

In [2]:
from google.colab import drive

drive.mount('/content/gdrive')
root_path = 'gdrive/My Drive/Colab Notebooks/'
'''
from google.colab import files
uploaded = files.upload()
file = 'shakespeare_input.txt'
'''

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


"\nfrom google.colab import files\nuploaded = files.upload()\nfile = 'shakespeare_input.txt'\n"

In [3]:
#opening file and reading data and making char2id,id2char dictionaries
filepath='/content/gdrive/My Drive/Colab Notebooks/shakespeare_input.txt'
text = open(filepath, 'rb').read().decode(encoding='utf-8')

def clean_doc(text):
	# replace '--' with a space ' '
	text = text.replace('--', ' ')
	# split into tokens by white space
	tokens = text.split(" ")
    
	# remove remaining tokens that are not alphabetic
	#tokens = [word for word in tokens if word.isalpha()]
	# make lower case
	#tokens = [word.lower() for word in tokens]
	return tokens

tokens = clean_doc(text)
words = sorted(list(set(tokens)))
word_indices = dict((w, i) for i, w in enumerate(words))
indices_word = dict((i, w) for i, w in enumerate(words))
tokens = np.array([word_indices[w] for w in tokens])
print(len(words))

138240


In [0]:
seq_length = 20
BUFFER_SIZE = 10000
BATCH_SIZE = 64
dataset = tf.data.Dataset.from_tensor_slices(tokens)
sequences = dataset.batch(seq_length+1, drop_remainder=True)

def split_input_target(c):
    input_ex = c[:-1]
    target = c[1:]
    return input_ex, target

dataset = sequences.map(split_input_target)
dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)

In [5]:
def build_model(vocab_size, embedding_dim, rnn_units, batch_size):
  '''
  This function builds a lstm model, with first layers as
  embedding. Second as lstm and then a dense layer with
  dictionary size output.
  '''
  model = tf.keras.Sequential([
  tf.keras.layers.Embedding(vocab_size, embedding_dim, batch_input_shape=[batch_size, None]),
  tf.keras.layers.LSTM(rnn_units,return_sequences=True,stateful=True,recurrent_initializer='glorot_uniform'),
  tf.keras.layers.Dense(vocab_size)])
  return model

model = build_model(vocab_size = len(words),embedding_dim=256,rnn_units=2048,batch_size=64)

model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding (Embedding)        (64, None, 256)           35389440  
_________________________________________________________________
lstm (LSTM)                  (64, None, 2048)          18882560  
_________________________________________________________________
dense (Dense)                (64, None, 138240)        283253760 
Total params: 337,525,760
Trainable params: 337,525,760
Non-trainable params: 0
_________________________________________________________________


In [0]:
def loss(labels, logits):
  
  '''
  This function returns the cross entropy loss.
  '''
  return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)

adam = tf.keras.optimizers.Adam(lr=0.001, beta_1=0.9, beta_2=0.999, amsgrad=False, clipnorm = 5)
model.compile(optimizer='adam', loss=loss)
checkpoint_dir = '/content/gdrive/My Drive/Colab Notebooks/Checkpoints_wordlevel'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
checkpoint_callback=tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,save_weights_only=True)

In [7]:
#fitting the model and plotting loss curve
history = model.fit(dataset, epochs=6, steps_per_epoch=64, callbacks=[checkpoint_callback])
plt.plot(history.history['loss'])

Epoch 1/6


W0704 06:30:32.390562 140227990177664 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


Epoch 2/6
Epoch 3/6
12/64 [====>.........................] - ETA: 1:18 - loss: 9.7688

KeyboardInterrupt: ignored

In [10]:
def generate_text(model, start_string):
  '''
  This function generates a sequence of characters by taking
  a seed as the start string.
  '''
  num_generate = 100
  input_eval = [word_indices[start_string]]
  
  input_eval = tf.expand_dims(input_eval, 0)
  text_generated = []
  temperature = 1.0
  model.reset_states()
  for i in range(num_generate):
      predictions = model(input_eval)
      predictions = tf.squeeze(predictions, 0)
      predictions = predictions / temperature
      predicted_id = tf.random.categorical(predictions, num_samples=1)[-1,0].numpy()
      #print(predictions)
      input_eval = tf.expand_dims([predicted_id], 0)
      text_generated.append(indices_word[predicted_id])

  return (start_string + ' ' +  ' '.join(text_generated))

tf.train.latest_checkpoint(checkpoint_dir)
model = build_model(vocab_size = len(words),embedding_dim=256,rnn_units=2048,batch_size=1)
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
model.build(tf.TensorShape([1, None]))
print(generate_text(model, start_string="Lord"))

Lord arms.

Second ado.

Second Servingman:
Why suit was ANTONY:
To of for news
I king;
Fear sir.

PAROLLES:
Go you sweat way good from death?
Ah, of is for lordship fairest affection.

LUCIO:
She daughter, suffer observe came thine Philoten: Mowbray,
A bondmen to high point
Of base jest: my from mean?
Love, another have scrolls an well him; swear he my sir, and call'd approaches nay, cousins far
Than England. secret will were, cannot to you me matter?

MENENIUS:
Now, an clouded purchaseth.

URSULA:
Sure, be him is thou so, you I ever came.

KING aught and his so behold. thing is follow the was, yet?

MARK grub,
Time of now done
Like ever me brother,
We less bawds I to father's less


In [11]:
tf.train.latest_checkpoint(checkpoint_dir)

'/content/gdrive/My Drive/Colab Notebooks/Checkpoints_wordlevel/ckpt_2'