<a href="https://colab.research.google.com/github/wooheehee/deeplearning-practice/blob/main/CH8_Char_RNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import sys
import time

import numpy as np
import tensorflow as tf
from absl import app

In [2]:
def split_input_target(chunk):
  """Split one window into an (input, target) pair for next-character prediction.

  The target is the input shifted left by one position, so target[t] is the
  character that follows input[t].

  Args:
    chunk: a sliceable sequence of length n (in this notebook, a window of
      seq_length + 1 character ids).

  Returns:
    The tuple (chunk[:-1], chunk[1:]); each element has length n - 1.
  """
  return chunk[:-1], chunk[1:]

In [3]:
# Download the Shakespeare text corpus (Keras reuses a cached copy if present).
# Fetched over HTTPS so the file cannot be altered in transit.
# Note: despite its name, `data_dir` holds the path to the downloaded *file*.
data_dir = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')
batch_size = 64       # training windows per batch
seq_length = 100      # characters per model input (each raw window is seq_length + 1)
embedding_dim = 256   # size of each character embedding vector
hidden_size = 1024    # number of LSTM units
num_epochs = 10       # full passes over the training dataset

Downloading data from http://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt


In [4]:
# Read the raw corpus in binary and decode explicitly as UTF-8 (no newline
# translation, no dependence on the platform default encoding). The `with`
# block guarantees the file handle is closed.
with open(data_dir, 'rb') as corpus_file:
  text = corpus_file.read().decode(encoding='utf-8')
# Character-level vocabulary: every distinct character, sorted so that the
# index assigned to each character is deterministic.
vocab = sorted(set(text))
vocab_size = len(vocab)
print('{} unique characters'.format(vocab_size))
# Lookup tables in both directions: char -> id (dict), id -> char (array).
char2idx = {u: i for i, u in enumerate(vocab)}
idx2char = np.array(vocab)

65 unique characters


In [5]:
# Encode the whole corpus: each character is replaced by its vocabulary id
# (a KeyError would signal a character missing from char2idx).
text_as_int = np.array(list(map(char2idx.__getitem__, text)))

In [6]:
# NOTE(review): `sequence_mask` is not used anywhere in this notebook and is
# imported from a private TensorFlow module (`tensorflow.python.*`); consider
# removing it.
from tensorflow.python.ops.array_ops import sequence_mask
# Stream the encoded corpus one character id at a time.
char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)
# Group ids into windows of seq_length + 1; the incomplete tail is dropped.
sequences = char_dataset.batch(batch_size=seq_length + 1, drop_remainder=True)
# Turn each window into an (input, target) pair offset by one character.
dataset = sequences.map(map_func=split_input_target)

In [7]:
# Shuffle whole (input, target) windows and group them into training batches.
# The pipeline is rebuilt from `sequences` (the unbatched windows) instead of
# reassigning `dataset = dataset...`, so re-running this cell cannot batch an
# already-batched dataset a second time.
buffer_size = 10000  # shuffle buffer size, measured in windows
dataset = (sequences
           .map(split_input_target)
           .shuffle(buffer_size)
           .batch(batch_size, drop_remainder=True))

In [8]:
class RNN(tf.keras.Model):
  """Character-level language model: embedding -> stateful LSTM -> logits.

  Args:
    batch_size: fixed batch dimension. The LSTM is stateful, so the batch size
      is baked into the model; this notebook builds one instance for training
      (batch_size) and a separate one for sampling (batch size 1).
  """

  def __init__(self, batch_size):
    super().__init__()
    # NOTE: attribute names are left unchanged; weights saved with
    # save_weights are presumably tracked by them, so renaming would break
    # loading existing checkpoints.
    # Maps each character id to a dense embedding_dim-sized vector.
    self.embedding_layer = tf.keras.layers.Embedding(
        input_dim=vocab_size,
        output_dim=embedding_dim,
        batch_input_shape=[batch_size, None])
    # return_sequences=True emits an output at every time step, not only the last.
    self.hidden_layer_1 = tf.keras.layers.LSTM(
        units=hidden_size,
        return_sequences=True,
        stateful=True,
        recurrent_initializer='glorot_uniform')
    # Unnormalized score for every vocabulary character at each position.
    self.output_layer = tf.keras.layers.Dense(units=vocab_size)

  def call(self, x):
    """Map character ids (batch, time) to logits (batch, time, vocab_size)."""
    return self.output_layer(self.hidden_layer_1(self.embedding_layer(x)))

In [9]:
def sparse_cross_entropy_loss(labels, logits):
  """Mean cross-entropy between integer labels and unnormalized logits.

  Args:
    labels: integer target ids (here the `target` half of each window).
    logits: raw scores with a trailing vocabulary axis, as returned by RNN.

  Returns:
    A scalar: the per-position loss averaged over every position in the batch.
  """
  per_position_loss = tf.keras.losses.sparse_categorical_crossentropy(
      labels, logits, from_logits=True)
  return tf.reduce_mean(per_position_loss)

In [10]:
# Adam with its standard default learning rate, written out explicitly.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

In [11]:
@tf.function
def train_step(model, input, target):
  """Run one optimization step on a single batch and return its loss.

  Uses the module-level `optimizer` to update `model`'s trainable variables.

  Args:
    model: the network being trained.
    input: batch of input character ids (name shadows the builtin; kept so
      keyword callers keep working).
    target: batch of target ids, i.e. the input shifted by one character.

  Returns:
    The scalar batch loss, computed before the weights are updated.
  """
  with tf.GradientTape() as tape:
    loss = sparse_cross_entropy_loss(target, model(input))
  variables = model.trainable_variables
  optimizer.apply_gradients(zip(tape.gradient(loss, variables), variables))
  return loss

In [12]:
def generate_text(model, start_string, num_sampling=4000, temperature=1.0):
  """Sample text from `model` one character at a time, seeded by `start_string`.

  The model's state is reset once, the whole seed is fed in, and afterwards
  only the most recently sampled character is fed back. Earlier context must
  therefore be carried by the model's own recurrent state (the RNN class in
  this notebook uses a stateful LSTM).

  Args:
    model: a model with batch size 1 (the output is squeezed on axis 0).
    start_string: non-empty seed text; every character must be in `char2idx`.
    num_sampling: number of characters to generate after the seed.
    temperature: positive divisor applied to the logits before sampling;
      values below 1.0 make output more conservative, above 1.0 more random.

  Returns:
    `start_string` followed by the `num_sampling` generated characters.

  Raises:
    ValueError: if `start_string` is empty or `temperature` is not positive.
  """
  if not start_string:
    # An empty seed would build an empty float tensor and fail deep inside TF.
    raise ValueError('start_string must contain at least one character')
  if temperature <= 0:
    raise ValueError('temperature must be positive, got {}'.format(temperature))

  # Shape (1, len(start_string)): a batch containing one sequence.
  input_eval = tf.expand_dims([char2idx[s] for s in start_string], 0)
  text_generated = []

  model.reset_states()
  for _ in range(num_sampling):
    # (1, time, vocab_size) -> (time, vocab_size)
    predictions = tf.squeeze(model(input_eval), 0)
    predictions = predictions / temperature
    # One sample is drawn per time step; only the last step's sample is used.
    predicted_id = tf.random.categorical(predictions, num_samples=1)[-1, 0].numpy()

    # Feed only the new character back in; the model state holds the rest.
    input_eval = tf.expand_dims([predicted_id], 0)
    text_generated.append(idx2char[predicted_id])

  return start_string + ''.join(text_generated)

In [13]:
def main(_):
  """Train the character RNN, checkpoint it, then print a generated sample.

  Args:
    _: argv forwarded by absl `app.run`; unused.

  Raises:
    RuntimeError: if no checkpoint is available to load for sampling.
  """
  RNN_model = RNN(batch_size=batch_size)

  # Push one batch through the untrained model to build it and show the
  # output shape before printing the layer summary.
  for input_example_batch, target_example_batch in dataset.take(1):
    example_batch_predictions = RNN_model(input_example_batch)
    print(example_batch_predictions.shape, "# (batch_size, sequence_length, vocab_size)")

  RNN_model.summary()

  checkpoint_dir = './training_checkpoints'
  # Checkpoints are labeled with the 0-based epoch index.
  checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
  save_every = 5  # epochs between periodic checkpoints

  for epoch in range(num_epochs):
    start = time.time()

    # The LSTM is stateful: clear state carried over from the previous epoch.
    RNN_model.reset_states()

    for (batch_n, (input_batch, target_batch)) in enumerate(dataset):
      loss = train_step(RNN_model, input_batch, target_batch)

      if batch_n % 100 == 0:
        template = 'Epoch {} Batch {} Loss {}'
        print(template.format(epoch+1, batch_n, loss))

    # Save every `save_every` epochs and after the final epoch, writing each
    # epoch's checkpoint at most once.
    if (epoch + 1) % save_every == 0 or epoch + 1 == num_epochs:
      RNN_model.save_weights(checkpoint_prefix.format(epoch=epoch))

    # NOTE: `loss` is the loss of the epoch's last batch, not an epoch average.
    print('Epoch {} Loss {:.4f}'.format(epoch+1, loss))
    print('Time taken for 1 epoch {} sec\n'.format(time.time()-start))

  print("트레이닝 끝")  # "Training finished"

  # Rebuild the network with batch size 1 for generation and load the
  # most recently saved weights.
  sampling_RNN_model = RNN(batch_size=1)
  latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
  if latest_checkpoint is None:
    raise RuntimeError('No checkpoint found in {}'.format(checkpoint_dir))
  sampling_RNN_model.load_weights(latest_checkpoint)
  sampling_RNN_model.build(tf.TensorShape([1, None]))
  sampling_RNN_model.summary()

  print("샘플링 시작")  # "Sampling started"
  print(generate_text(sampling_RNN_model, start_string=u' '))

In [None]:
if __name__ == '__main__':
  # absl parses sys.argv, but a Jupyter/Colab kernel is launched with extra
  # arguments (e.g. `-f <connection-file>.json`) that absl rejects as unknown
  # flags. This notebook defines no flags, so give absl only the program name.
  # absl's app.run ends by calling sys.exit; IPython reports the resulting
  # SystemExit rather than crashing the kernel.
  app.run(main, argv=sys.argv[:1])