This notebook is based on the notebook below, with a Japanese translation, some restructuring of the sections, and additions.
https://colab.research.google.com/drive/13Vr3PrDg7cc4OZ3W2-grLSVSf0RJYWzb


# A TensorFlow implementation of Char-RNN

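Char-RNN is a well-known text generation model (a character-level LSTM) created by Andrej Karpathy. It can be trained on any text and then used to generate new text with little effort. Some entertaining example outputs are: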

  * Music: abc notation
<https://highnoongmt.wordpress.com/2015/05/22/lisls-stis-recurrent-neural-networks-for-folk-music-generation/>
  * Irish folk music
<https://soundcloud.com/seaandsailor/sets/char-rnn-composes-irish-folk-music>
  * Obama speeches
<https://medium.com/@samim/obama-rnn-machine-generated-political-speeches-c8abd18a2ea0>
  * Eminem lyrics (NSFW ;-))
<https://soundcloud.com/mrchrisjohnson/recurrent-neural-shady>
  * Research awards
<http://karpathy.github.io/2015/05/21/rnn-effectiveness/#comment-2073825449>
  * TED Talks
<https://medium.com/@samim/ted-rnn-machine-generated-ted-talks-3dd682b894c0>
  * Movie Titles
<http://www.cs.toronto.edu/~graves/handwriting.html>

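This notebook is a reimplementation in TensorFlow. It walks through the whole workflow: supply the text you want the generator to imitate, train the model, inspect the results, and save the trained model for later use.

To get started, follow the instructions at each step and run the cells in order. When prompted to upload, you will need a large text file (at least 1 MB). If you do not have one, you can use the prepared corpus of Shakespeare's works instead.

The training cell saves a checkpoint every 30 seconds. You can also check the network's output along the way and resume training partway through.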

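## Overview

This notebook proceeds through the following steps:

- Upload the data
- Set the hyperparameters (the defaults also work)
- Define the model, the training loss function, and the data input manager
- Train the model on a cloud GPU
- Save the model and use it to generate new text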


The RNN design follows [this GitHub project](https://github.com/sherjilozair/char-rnn-tensorflow), which is itself based on Andrej Karpathy's [char-rnn](https://github.com/karpathy/char-rnn). For more detail, Andrej's [blog post](http://karpathy.github.io/2015/05/21/rnn-effectiveness/) is a good place to start.


## A. Set up the environment

Import the modules needed to run the code and define the constants.


In [None]:
from __future__ import absolute_import, print_function, division
from google.colab import files
from collections import Counter, defaultdict
from copy import deepcopy
from IPython.display import clear_output
from random import randint

import json
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

CHECKPOINT_DIR = './checkpoints/'  #Checkpoints are temporarily kept here.
TEXT_ENCODING = 'utf-8'


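## B. Prepare the dataset

You can download the works of Shakespeare and use them as the training dataset.

If you would rather use a text file of your own, upload it and use it as the training dataset instead.


### 1. Download the dataset

By default, the works of Shakespeare are used as the dataset.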

In [None]:
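shakespeare_url = "https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt"
try:
  from urllib.request import urlopen  # Python 3
except ImportError:
  from urllib import urlopen  # Python 2
# Decode the downloaded bytes so that file_contents is text.
file_contents = urlopen(shakespeare_url).read().decode(TEXT_ENCODING)
file_name = "shakespeare"
file_contents = file_contents[10501:]  # Skip headers and start at content
print("An excerpt: \n", file_contents[:664])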

An excerpt: 
                      1
  From fairest creatures we desire increase,
  That thereby beauty's rose might never die,
  But as the riper should by time decease,
  His tender heir might bear his memory:
  But thou contracted to thine own bright eyes,
  Feed'st thy light's flame with self-substantial fuel,
  Making a famine where abundance lies,
  Thy self thy foe, to thy sweet self too cruel:
  Thou that art now the world's fresh ornament,
  And only herald to the gaudy spring,
  Within thine own bud buriest thy content,
  And tender churl mak'st waste in niggarding:
    Pity the world, or else this glutton be,
    To eat the world's due, by the grave and thee.


### 2. (Optional) Use your own dataset


To use your own training dataset, run the next two cells. Otherwise, skip them.


In [None]:
uploaded = files.upload()

In [None]:
if uploaded:
  # Handle return values that wrap the dict in a .files attribute.
  if not isinstance(uploaded, dict): uploaded = uploaded.files
  file_name = list(uploaded.keys())[0]  # Use the first uploaded file.
  file_bytes = uploaded[file_name]
  file_contents = file_bytes.decode(TEXT_ENCODING)
print("An excerpt: \n", file_contents[:664])

An excerpt: 
                      1
  From fairest creatures we desire increase,
  That thereby beauty's rose might never die,
  But as the riper should by time decease,
  His tender heir might bear his memory:
  But thou contracted to thine own bright eyes,
  Feed'st thy light's flame with self-substantial fuel,
  Making a famine where abundance lies,
  Thy self thy foe, to thy sweet self too cruel:
  Thou that art now the world's fresh ornament,
  And only herald to the gaudy spring,
  Within thine own bud buriest thy content,
  And tender churl mak'st waste in niggarding:
    Pity the world, or else this glutton be,
    To eat the world's due, by the grave and thee.


## C. Preprocess the dataset


First, convert the plain-text file into an array of tokens. The `TokenMapper` helper class below does this.


In [1]:
import string
from operator import itemgetter  # Used by alphabet() and print() below.
class TokenMapper(object):
  def __init__(self):
    self.token_mapping = {}
    self.reverse_token_mapping = {}
  def buildFromData(self, utf8_string, limit=0.00004):
    print("Build token dictionary.")
    # Accept both raw bytes and already-decoded text.
    if isinstance(utf8_string, bytes):
      utf8_string = utf8_string.decode('utf8')
    total_num = len(utf8_string)
    sorted_tokens = sorted(Counter(utf8_string).items(),
                           key=lambda x: -x[1])
    # Filter tokens: keep printable characters (not control chars), plus any
    # other character that is reasonably common, i.e. skip strange esoteric
    # characters in order to reduce the dictionary size.
    filtered_tokens = filter(lambda t: t[0] in string.printable or
                             float(t[1])/total_num > limit, sorted_tokens)
    tokens, counts = zip(*filtered_tokens)
    self.token_mapping = dict(zip(tokens, range(len(tokens))))
    for c in string.printable:
      if c not in self.token_mapping:
        print("Skipped token for: ", c)
    self.reverse_token_mapping = {
        val: key for key, val in self.token_mapping.items()}
    print("Created dictionary: %d tokens"%len(self.token_mapping))
  
  def mapchar(self, char):
    if char in self.token_mapping:
      return self.token_mapping[char]
    else:
      return self.token_mapping[' ']
  
  def mapstring(self, utf8_string):
    return [self.mapchar(c) for c in utf8_string]
  
  def maptoken(self, token):
    return self.reverse_token_mapping[token]
  
  def maptokens(self, int_array):
    return ''.join([self.reverse_token_mapping[c] for c in int_array])
  
  def size(self):
    return len(self.token_mapping)
  
  def alphabet(self):
    return ''.join([k for k,v in sorted(self.token_mapping.items(),key=itemgetter(1))])

  def print(self):
    for k,v in sorted(self.token_mapping.items(),key=itemgetter(1)): print(k, v)
  
  def save(self, path):
    with open(path, 'w') as json_file:  # json.dump writes text, not bytes.
      json.dump(self.token_mapping, json_file)
  
  def restore(self, path):
    with open(path, 'r') as json_file:
      self.token_mapping = {}
      self.token_mapping.update(json.load(json_file))
      self.reverse_token_mapping = {val: key for key, val in self.token_mapping.items()}
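
As a rough illustration of what the mapper does, the sketch below builds a character-to-integer mapping for a toy string and maps text back and forth. The names `build_mapping`, `encode` and `decode` are hypothetical and not part of the notebook, and unlike `TokenMapper` this sketch applies no rarity filtering.

```python
from collections import Counter

def build_mapping(text):
  # Most frequent characters get the smallest token ids;
  # ties are broken alphabetically to keep the result deterministic.
  ordered = sorted(Counter(text).items(), key=lambda kv: (-kv[1], kv[0]))
  return {char: idx for idx, (char, _) in enumerate(ordered)}

def encode(text, mapping, fallback=' '):
  # Unknown characters fall back to the space token, as in mapchar().
  return [mapping.get(c, mapping[fallback]) for c in text]

def decode(tokens, mapping):
  reverse = {idx: char for char, idx in mapping.items()}
  return ''.join(reverse[t] for t in tokens)

toy_mapping = build_mapping("to be or not to be")
toy_tokens = encode("to be", toy_mapping)
toy_roundtrip = decode(toy_tokens, toy_mapping)
print(toy_tokens, repr(toy_roundtrip))
```

Mapping unknown characters to the space token keeps the vocabulary small, at the cost of losing rare characters when the text is decoded again.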

Let's convert the raw input into a list of tokens.

In [2]:
# Clean the checkpoint directory and make a fresh one
!rm -rf {CHECKPOINT_DIR}
!mkdir {CHECKPOINT_DIR}
!ls -lt

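# NOTE: sequence_length and batch_size are set in the hyperparameter cell
# (section D.2). Run that cell first, or this line raises a NameError.
chars_in_batch = (sequence_length * batch_size)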
file_len = len(file_contents)
unique_sequential_batches = file_len // chars_in_batch

mapper = TokenMapper()
mapper.buildFromData(file_contents)
mapper.save(''.join([CHECKPOINT_DIR, 'token_mapping.json']))

input_values = mapper.mapstring(file_contents)

total 12416
drwxr-xr-x  2 tomo.masuda  staff       64  5  6 15:55 {CHECKPOINT_DIR}
-rw-r--r--  1 tomo.masuda  staff    49828  5  6 15:53 Char_RNN_ja.ipynb
-rw-r--r--  1 tomo.masuda  staff    14926  5  6 15:46 TF_Hub_Universal_Encoder_ja.ipynb
-rw-r--r--  1 tomo.masuda  staff    23647  5  6 15:43 Text_classification_with_TF_Hub_ja.ipynb
-rw-r--r--  1 tomo.masuda  staff    95773  5  6 12:34 Basic_Text_Classification_ja.ipynb
-rw-r--r--  1 tomo.masuda  staff   108791  5  6 12:27 BigGAN_TF_Hub_Demo_ja.ipynb
-rw-r--r--  1 tomo.masuda  staff    15742  5  6 11:48 Compare_GAN_ja.ipynb
-rwxr-xr-x@ 1 tomo.masuda  staff    13053  5  6 11:42 Action_Recognition_on_the_UCF101_Dataset_ja_only.ipynb
-rw-r--r--  1 tomo.masuda  staff   777239  5  6 11:41 Deepdream_ja.ipynb
-rw-r--r--  1 tomo.masuda  staff    14102  5  6 11:24 TF_Hub_Delf_module_ja.ipynb
-rw-r--r--  1 tomo.masuda  staff   674878  5  6 11:05 Transfer_Learning_ja.ipynb
-rw-r--r--  1 tomo.masuda  staff    31694  5  6 1

NameError: name 'sequence_length' is not defined


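## D. Build the model

### 1. Create the LSTM model

First, we need to decide on the structure of the neural network. The next cell creates a class that contains the TensorFlow graph making up the network and the training hyperparameters.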


In [None]:
class RNN(object):
  """Represents a Recurrent Neural Network using LSTM cells.

  Attributes:
    num_layers: The integer number of hidden layers in the RNN.
    state_size: The size of the state in each LSTM cell.
    num_classes: Number of output classes. (E.g. 256 for Extended ASCII).
    batch_size: The number of training sequences to process per step.
    sequence_length: The number of chars in a training sequence.
    batch_index: Index within the dataset to start the next batch at.
    on_gpu_sequences: Generates the training inputs for a single batch.
    on_gpu_targets: Generates the training labels for a single batch.
    input_symbol: Placeholder for a single label for use during inference.
    temperature: Used when sampling outputs. A higher temperature will yield
      more variance; a lower one will produce the most likely outputs. Value
      should be between 0 and 1.
    initial_state: The LSTM State Tuple to initialize the network with. This
      will need to be set to the new_state computed by the network each cycle.
    logits: Unnormalized probability distribution for the next predicted
      label, for each timestep in each sequence.
    output_labels: A [batch_size, 1] int32 tensor containing a predicted
      label for each sequence in a batch. Only generated in infer mode.
  """
  def __init__(self,
               rnn_num_layers=1,
               rnn_state_size=128,
               num_classes=256,
               rnn_batch_size=1,
               rnn_sequence_length=1):
    self.num_layers = rnn_num_layers
    self.state_size = rnn_state_size
    self.num_classes = num_classes
    self.batch_size = rnn_batch_size
    self.sequence_length = rnn_sequence_length
    self.batch_shape = (self.batch_size, self.sequence_length)
    print("Built LSTM: ",
          self.num_layers ,self.state_size ,self.num_classes ,
          self.batch_size ,self.sequence_length ,self.batch_shape)


  def build_training_model(self, dropout_rate, data_to_load):
    """Sets up an RNN model for running a training job.

    Args:
      dropout_rate: The rate at which weights may be forgotten during training.
      data_to_load: A numpy array containing the training data, with each
        element in data_to_load being an integer representing a label. For
        example, for Extended ASCII, values may be 0 through 255.

    Raises:
      ValueError: If data_to_load is None.
    """
    if data_to_load is None:
      raise ValueError('To continue, you must upload training data.')
    inputs = self._set_up_training_inputs(data_to_load)
    self._build_rnn(inputs, dropout_rate)

  def build_inference_model(self):
    """Sets up an RNN model for generating a sequence element by element.
    """
    self.input_symbol = tf.placeholder(shape=[1, 1], dtype=tf.int32)
    self.temperature = tf.placeholder(shape=(), dtype=tf.float32,
                                      name='temperature')
    self.num_options = tf.placeholder(shape=(), dtype=tf.int32,
                                      name='num_options')
    self._build_rnn(self.input_symbol, 0.0)

    self.temperature_modified_logits = tf.squeeze(
        self.logits, 0) / self.temperature

    #for beam search
    self.normalized_probs = tf.nn.softmax(self.logits)

    self.output_labels = tf.multinomial(self.temperature_modified_logits,
                                        self.num_options)

  def _set_up_training_inputs(self, data):
    self.batch_index = tf.placeholder(shape=(), dtype=tf.int32)
    batch_input_length = self.batch_size * self.sequence_length

    input_window = tf.slice(tf.constant(data, dtype=tf.int32),
                            [self.batch_index],
                            [batch_input_length + 1])

    self.on_gpu_sequences = tf.reshape(
        tf.slice(input_window, [0], [batch_input_length]), self.batch_shape)

    self.on_gpu_targets = tf.reshape(
        tf.slice(input_window, [1], [batch_input_length]), self.batch_shape)

    return self.on_gpu_sequences

  def _build_rnn(self, inputs, dropout_rate):
    """Generates an RNN model using the passed functions.

    Args:
      inputs: int32 Tensor with shape [batch_size, sequence_length] containing
        input labels.
      dropout_rate: A floating point value determining the chance that a weight
        is forgotten during evaluation.
    """
    # Alias some commonly used functions
    dropout_wrapper = tf.contrib.rnn.DropoutWrapper
    lstm_cell = tf.contrib.rnn.LSTMCell
    multi_rnn_cell = tf.contrib.rnn.MultiRNNCell

    self._cell = multi_rnn_cell(
        [dropout_wrapper(lstm_cell(self.state_size), 1.0, 1.0 - dropout_rate)
         for _ in range(self.num_layers)])

    self.initial_state = self._cell.zero_state(self.batch_size, tf.float32)

    embedding = tf.get_variable('embedding',
                                [self.num_classes, self.state_size])

    embedding_input = tf.nn.embedding_lookup(embedding, inputs)
    output, self.new_state = tf.nn.dynamic_rnn(self._cell, embedding_input,
                                               initial_state=self.initial_state)

    self.logits = tf.contrib.layers.fully_connected(output, self.num_classes,
                                                    activation_fn=None)
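
The input slicing in `_set_up_training_inputs` may be easier to follow with plain NumPy. The sketch below is an illustration only, with a hypothetical helper `make_batch` and toy sizes: it takes one window of `batch_size * sequence_length + 1` tokens and produces inputs and targets that are offset by one position, so each target is the token that follows its input.

```python
import numpy as np

def make_batch(data, batch_index, batch_size, sequence_length):
  n = batch_size * sequence_length
  # One extra token is needed so the last input still has a target.
  window = data[batch_index:batch_index + n + 1]
  inputs = window[:n].reshape(batch_size, sequence_length)
  targets = window[1:n + 1].reshape(batch_size, sequence_length)
  return inputs, targets

toy_data = np.arange(10, dtype=np.int32)
toy_inputs, toy_targets = make_batch(
    toy_data, batch_index=2, batch_size=2, sequence_length=3)
print(toy_inputs)
print(toy_targets)
```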



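### 2. Define the hyperparameters

Set the hyperparameters used for training here. Parameters for text generation are defined later, at inference time.

For now, run this cell with the default values. You can come back to this cell later to tune the parameters.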



In [None]:
num_layers = 2
state_size = 256
batch_size = 64
sequence_length = 256
num_training_steps = 30000 # Takes about 40 minutes
steps_per_epoch = 500
learning_rate = 0.002
learning_rate_decay = 0.95
gradient_clipping = 5.0


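### 3. Define the loss function

The loss measures how well the neural network models the data distribution.

We pass in the model outputs (`logits`) and the training targets (`targets`). `target_weights` can be used to weight individual targets; in this notebook all weights are 1, so no target is favored over another.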


In [None]:
def get_loss(logits, targets, target_weights):
  with tf.name_scope('loss'):
    return tf.contrib.seq2seq.sequence_loss(
        logits,
        targets,
        target_weights,
        average_across_timesteps=True)
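
To make the quantity being minimized concrete, here is a minimal NumPy sketch of a weighted, averaged cross-entropy over a batch of sequences. It assumes averaging over both timesteps and batch, and the helper name `sequence_loss_np` is hypothetical; it is meant to illustrate the idea, not to reproduce the TensorFlow implementation.

```python
import numpy as np

def sequence_loss_np(logits, targets, weights):
  # logits: [batch, time, classes]; targets and weights: [batch, time].
  shifted = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
  log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
  batch_idx, time_idx = np.indices(targets.shape)
  # Negative log-likelihood of the correct class at every position.
  nll = -log_probs[batch_idx, time_idx, targets]
  return (nll * weights).sum() / weights.sum()

toy_logits = np.zeros((2, 3, 4))  # a uniform prediction over 4 classes
toy_targets = np.array([[0, 1, 2], [3, 0, 1]])
toy_weights = np.ones((2, 3))     # all ones, as in this notebook
toy_loss = sequence_loss_np(toy_logits, toy_targets, toy_weights)
print(toy_loss)
```

With a uniform prediction over 4 classes, the loss equals ln 4, which gives a useful baseline: an untrained model over `num_classes` characters should start near ln(`num_classes`).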


### 4. Define the optimizer

This tells TensorFlow which optimization method to use to reduce the loss. We use the popular [ADAM algorithm](https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer).


In [None]:
def get_optimizer(loss, initial_learning_rate, gradient_clipping, global_step,
                  decay_steps, decay_rate):

  with tf.name_scope('optimizer'):
    computed_learning_rate = tf.train.exponential_decay(
        initial_learning_rate,
        global_step,
        decay_steps,
        decay_rate,
        staircase=True)

    optimizer = tf.train.AdamOptimizer(computed_learning_rate)
    trained_vars = tf.trainable_variables()
    gradients, _ = tf.clip_by_global_norm(
        tf.gradients(loss, trained_vars),
        gradient_clipping)
    training_op = optimizer.apply_gradients(
        zip(gradients, trained_vars),
        global_step=global_step)

    return training_op, computed_learning_rate
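
The two numeric ideas in this cell, a staircase learning-rate decay and clipping by global norm, can be sketched in plain Python. The helper names below are hypothetical and operate on plain lists rather than tensors; the learning-rate values reuse this notebook's defaults (`learning_rate = 0.002`, `steps_per_epoch = 500`, `learning_rate_decay = 0.95`).

```python
import math

def decayed_learning_rate(initial_lr, global_step, decay_steps, decay_rate):
  # staircase=True: the exponent only changes every decay_steps steps.
  return initial_lr * decay_rate ** (global_step // decay_steps)

def clip_by_global_norm(gradients, clip_norm):
  # gradients: a list of lists of floats, one inner list per variable.
  global_norm = math.sqrt(sum(g * g for grad in gradients for g in grad))
  # Gradients are rescaled only when the global norm exceeds clip_norm.
  scale = clip_norm / max(global_norm, clip_norm)
  return [[g * scale for g in grad] for grad in gradients], global_norm

lr_at_499 = decayed_learning_rate(0.002, 499, 500, 0.95)
lr_at_500 = decayed_learning_rate(0.002, 500, 500, 0.95)
clipped, norm = clip_by_global_norm([[3.0, 4.0], [12.0]], clip_norm=5.0)
print(lr_at_499, lr_at_500)
print(norm, clipped)
```

Clipping by the global norm rescales all gradients by the same factor, so the direction of the update is preserved while its size is bounded.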


### 5. Define the training progress monitor

This class lets you monitor the progress of training.


In [None]:
class LossPlotter(object):
  def __init__(self, history_length):
    self.global_steps = []
    self.losses = []
    self.averaged_loss_x = []
    self.averaged_loss_y = []
    self.history_length = history_length

  def draw_plots(self):
    self._update_averages(self.global_steps, self.losses,
                          self.averaged_loss_x, self.averaged_loss_y)

    plt.title('Average Loss Over Time')
    plt.xlabel('Global Step')
    plt.ylabel('Loss')
    plt.plot(self.averaged_loss_x, self.averaged_loss_y, label='Loss/Time (Avg)')
    plt.plot()
    plt.plot(self.global_steps, self.losses,
             label='Loss/Time (Last %d)' % self.history_length,
             alpha=.1, color='r')
    plt.plot()
    plt.legend()
    plt.show()

    plt.title('Loss for the last 100 Steps')
    plt.xlabel('Global Step')
    plt.ylabel('Loss')
    plt.plot(self.global_steps, self.losses,
             label='Loss/Time (Last %d)' % self.history_length, color='r')
    plt.plot()
    plt.legend()
    plt.show()

    # The notebook will be slowed down at the end of training if we plot the
    # entire history of raw data. Plot only the last 100 steps of raw data,
    # and the average of each 100 batches. Don't keep unused data.
    self.global_steps = []
    self.losses = []
    self.learning_rates = []

  def log_step(self, global_step, loss):
    self.global_steps.append(global_step)
    self.losses.append(loss)

  def _update_averages(self, x_list, y_list,
                       averaged_data_x, averaged_data_y):
    averaged_data_x.append(x_list[-1])
    averaged_data_y.append(sum(y_list) / self.history_length)


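Training the model takes a while, so feel free to grab a coffee while you wait. A checkpoint is saved every 30 seconds of training so that progress is not lost. To see how training is going, stop training from time to time, run the inference cells, and generate text with the model trained so far.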



### 6. Construct the model

Define the model and add the training operations to the TensorFlow graph.

If you want to continue training after testing the generator, run the next three cells.


In [None]:
tf.reset_default_graph()
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
print('Constructing model...')

model = RNN(
    rnn_num_layers=num_layers,
    rnn_state_size=state_size,
    num_classes=mapper.size(),
    rnn_batch_size=batch_size,
    rnn_sequence_length=sequence_length)

model.build_training_model(0.05, np.asarray(input_values))
print('Constructed model successfully.')

print('Setting up training session...')
neutral_target_weights = tf.constant(
    np.ones(model.batch_shape),
    tf.float32
)
loss = get_loss(model.logits, model.on_gpu_targets, neutral_target_weights)
global_step = tf.get_variable('global_step', shape=(), trainable=False,
                              dtype=tf.int32)
training_step, computed_learning_rate = get_optimizer(
    loss,
    learning_rate,
    gradient_clipping,
    global_step,
    steps_per_epoch,
    learning_rate_decay
)

Constructing model...
Built LSTM:  2 256 84 64 256 (64, 256)
Constructed model successfully.
Setting up training session...



The `Supervisor` manages the training flow and the checkpoints.


In [None]:
# Create a supervisor that will checkpoint the model in the CHECKPOINT_DIR
sv = tf.train.Supervisor(
    logdir=CHECKPOINT_DIR,
    global_step=global_step,
    save_model_secs=30)
print('Training session ready.')

Training session ready.




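## E. Train the model

The next cell starts the training cycle. If a previous checkpoint exists, it first tries to resume training from where it left off, and then continues the training process.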


In [None]:
from datetime import datetime
start_time = datetime.now()

with sv.managed_session(config=config) as sess:
  print('Training supervisor successfully initialized all variables.')
  if not file_len:
    raise ValueError('To continue, you must upload training data.')
  elif file_len < chars_in_batch:
    raise ValueError('To continue, you must upload a larger set of data.')

  plotter = LossPlotter(100)
  step_number = sess.run(global_step)
  zero_state = sess.run([model.initial_state])
  max_batch_index = (unique_sequential_batches - 1) * chars_in_batch
  while not sv.should_stop() and step_number < num_training_steps:
    feed_dict = {
        model.batch_index: randint(0, max_batch_index),
        model.initial_state: zero_state
        }
    [_, _, training_loss, step_number, current_learning_rate, _] = sess.run(
        [model.on_gpu_sequences,
         model.on_gpu_targets,
         loss,
         global_step,
         computed_learning_rate,
         training_step],
        feed_dict)
    plotter.log_step(step_number, training_loss)
    if step_number % 100 == 0:
      clear_output(True)
      plotter.draw_plots()
      print('Latest checkpoint is: %s' %
            tf.train.latest_checkpoint(CHECKPOINT_DIR))
      print('Learning Rate is: %f' %
            current_learning_rate)

    if step_number % 10 == 0:
      print('global step %d, loss=%f' % (step_number, training_loss))

clear_output(True)

print('Training completed in HH:MM:SS = ', datetime.now()-start_time)
print('Latest checkpoint is: %s' %
      tf.train.latest_checkpoint(CHECKPOINT_DIR))


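## F. Evaluate the trained model


### 1. Generate text

Now let's generate some text! Here we use the **beam search** algorithm together with the trained model. At each step, beam search expands each current option into N candidate next options. Even if the generator picks a "poor" item, the search can discard that bad result and keep following more likely choices.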
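
The idea can be sketched without TensorFlow. The example below is a textbook-style beam search over a hypothetical fixed probability table (`toy_next_probs` stands in for the RNN); it ranks sequences by summed log-probabilities and expands every possible next token, whereas the notebook's implementation samples its candidates from the model and sums raw probabilities.

```python
import math

def toy_next_probs(sequence):
  # Hypothetical stand-in for the RNN: a fixed table keyed by the last token.
  table = {
      'a': {'a': 0.1, 'b': 0.6, 'c': 0.3},
      'b': {'a': 0.5, 'b': 0.1, 'c': 0.4},
      'c': {'a': 0.7, 'b': 0.2, 'c': 0.1},
  }
  return table[sequence[-1]]

def beam_search(start, steps, num_beams):
  beams = [(0.0, [start])]  # (log-probability, sequence)
  for _ in range(steps):
    expanded = []
    for score, seq in beams:
      for token, prob in toy_next_probs(seq).items():
        expanded.append((score + math.log(prob), seq + [token]))
    # Keep only the num_beams highest-scoring sequences.
    beams = sorted(expanded, key=lambda b: b[0], reverse=True)[:num_beams]
  return beams

best_score, best_sequence = beam_search('a', steps=3, num_beams=2)[0]
print(''.join(best_sequence), math.exp(best_score))
```

With `num_beams=1` this reduces to greedy decoding; larger values keep more alternatives alive at the cost of more model evaluations per step.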

In [None]:
class BeamSearchCandidate(object):
  """Represents a node within the search space during Beam Search.

  Attributes:
    state: The resulting RNN state after the given sequence has been generated.
    sequence: The sequence of selections leading to this node.
    probability: The probability of the sequence occurring, computed as the sum
      of the probability of each character in the sequence at its respective
      step.
  """

  def __init__(self, init_state, sequence, probability):
    self.state = init_state
    self.sequence = sequence
    self.probability = probability

  def search_from(self, tf_sess, rnn_model, temperature, num_options):
    """Expands the num_options most likely next elements in the sequence.

    Args:
      tf_sess: The Tensorflow session containing the rnn_model.
      rnn_model: The RNN to use to generate the next element in the sequence.
      temperature: Modifies the probabilities of each character, placing
        more emphasis on higher probabilities as the value approaches 0.
      num_options: How many potential next options to expand from this one.

    Returns: A list of BeamSearchCandidate objects descended from this node.
    """
    expanded_set = []
    feed = {rnn_model.input_symbol: np.array([[self.sequence[-1]]]),
            rnn_model.initial_state: self.state,
            rnn_model.temperature: temperature,
            rnn_model.num_options: num_options}
    [predictions, probabilities, new_state] = tf_sess.run(
        [rnn_model.output_labels,
         rnn_model.normalized_probs,
         rnn_model.new_state], feed)
    # Get the indices of the num_beams next picks
    picks = [predictions[0][x] for x in range(len(predictions[0]))]
    for new_char in picks:
      new_seq = deepcopy(self.sequence)
      new_seq.append(new_char)
      expanded_set.append(
          BeamSearchCandidate(new_state, new_seq,
                              probabilities[0][0][new_char] + self.probability))
    return expanded_set

  def __eq__(self, other):
    return self.sequence == other.sequence

  def __ne__(self, other):
    return not self.__eq__(other)

  def __hash__(self):
    return hash(tuple(self.sequence))  # sequence is a list, so hash a tuple copy

In [None]:
def beam_search_generate_sequence(tf_sess, rnn_model, primer, temperature=0.85,
                                  termination_condition=None, num_beams=5):
  """Implements a sequence generator using Beam Search.

  Args:
    tf_sess: The Tensorflow session containing the rnn_model.
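    rnn_model: The RNN to use to generate the next element in the sequence.
    primer: A list of integer tokens used to prime the RNN state; generation
      continues from its last element.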
    temperature: Controls how 'Creative' the generated sequence is. Values
      close to 0 tend to generate the most likely sequence, while values
      closer to 1 generate more original sequences. Acceptable values are
      within (0, 1].
    termination_condition: A function taking one parameter, a list of
      integers, that returns True when a condition is met that signals to the
      RNN to return what it has generated so far.
    num_beams: The number of possible sequences to keep at each step of the
      generation process.

  Returns: A list of at most num_beams BeamSearchCandidate objects.
  """
  candidates = []

  rnn_current_state = tf_sess.run([rnn_model.initial_state])
  #Initialize the state for the primer
  for primer_val in primer[:-1]:
    feed = {rnn_model.input_symbol: np.array([[primer_val]]),
            rnn_model.initial_state: rnn_current_state
           }
    [rnn_current_state] = tf_sess.run([rnn_model.new_state], feed)

  candidates.append(BeamSearchCandidate(rnn_current_state, primer, num_beams))

  while True not in [termination_condition(x.sequence) for x in candidates]:
    new_candidates = []
    for candidate in candidates:
      expanded_candidates = candidate.search_from(
          tf_sess, rnn_model, temperature, num_beams)
      for new in expanded_candidates:
        if new not in new_candidates:
          #do not reevaluate duplicates
          new_candidates.append(new)
    candidates = sorted(new_candidates,
                        key=lambda x: x.probability, reverse=True)[:num_beams]

  return [c for c in candidates if termination_condition(c.sequence)]

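Specify the initial string that text generation starts from, and the length of the text.

"Creativity" controls how strongly the model favors matching the patterns it has learned. If the output is repetitive, increase this value. If the output seems too random, try lowering it.

If the results do not look convincing, run the three training cells for a while longer. The lower the loss, the more closely the generated text should resemble the training dataset.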
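
The effect of the creativity (temperature) value can be illustrated with a small softmax sketch. The inference graph divides the logits by the temperature before sampling; the helper `softmax_with_temperature` and the toy logits below are assumptions made for illustration only.

```python
import math

def softmax_with_temperature(logits, temperature):
  scaled = [x / temperature for x in logits]
  peak = max(scaled)  # subtract the max for numerical stability
  exps = [math.exp(x - peak) for x in scaled]
  total = sum(exps)
  return [e / total for e in exps]

toy_logits = [2.0, 1.0, 0.0]
sharp = softmax_with_temperature(toy_logits, 0.2)    # low creativity
neutral = softmax_with_temperature(toy_logits, 1.0)  # unmodified distribution
print([round(p, 3) for p in sharp])
print([round(p, 3) for p in neutral])
```

A low temperature concentrates almost all probability on the most likely character, which tends to produce repetitive text; values closer to 1 leave more room for less likely characters.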

In [None]:
tf.reset_default_graph()
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.InteractiveSession(config=config)

model = RNN(
    rnn_num_layers=num_layers,
    rnn_state_size=state_size,
    num_classes=mapper.size(),
    rnn_batch_size=1,
    rnn_sequence_length=1)

model.build_inference_model()

sess.run(tf.global_variables_initializer())
saver = tf.train.Saver(tf.global_variables())
ckpt = tf.train.latest_checkpoint(CHECKPOINT_DIR)
saver.restore(sess, ckpt)

def gen(start_with, pred, creativity):
  int_array = mapper.mapstring(start_with)
  candidates = beam_search_generate_sequence(
      sess, model, int_array, temperature=creativity,
      termination_condition=pred,
      num_beams=1)
  gentext = mapper.maptokens(candidates[0].sequence)
  return gentext

def lengthlimit(n):
  return lambda text: len(text)>n
def sentences(n):
  return lambda text: mapper.maptokens(text).count(".")>=n
def paragraph():
  return lambda text: mapper.maptokens(text).count("\n")>0



Built LSTM:  2 256 84 1 1 (1, 1)
INFO:tensorflow:Restoring parameters from ./checkpoints/model.ckpt-7270


In [None]:
length_of_generated_text = 2000
creativity = 0.85  # Should be greater than 0 but less than 1

print(gen("  ANTONIO: Who is it ?", lengthlimit(length_of_generated_text), creativity))

  ANTONIO: Who is it ?IE. Do piy min, wher by till blingestn.
    Mave in wind for ne cient hafesteres one for yor your yaev yould and londond, afrely,,
    At that Wood tho your you, waecth wing Cichbry your,
    I theer mien, ward the me a see hen thy yould berliiver to her mesting of rive ore yout,
     Thoush F'er in helr all hive und so bing the nost atHer with stourss the madss'
    Kom no, you shat with thing in shee and wear lat yom horch.
    Aqhat wirte you dose ou? 'con acrerante,
    Fow his nast sthe, lost the sarthind.
    Read und is liubk.                                                                         Anmen swenely they for musmer, with the woruster. Aurles I I hell with she laikn.
    Mush me  tid. Why will nle bes by lothile Vith is entern-that recume thee, anlind on whou to rudes
    Bust a hath tull a pnane form gowet proncud, here and thou theat sam,
    Tull at the me then with you sowe braen wune
    Susralshes to gresheneuy, and tleacenss
    thall a th


### 2. Save the trained model

Save a copy of the trained RNN so that it can be used later.


In [None]:
save_model_to_drive = False  ## Set this to true to save directly to Google Drive.

def save_model_hyperparameters(path):
  with open(path, 'w')  as json_file:
    model_params = {
        'num_layers': model.num_layers,
        'state_size': model.state_size,
        'num_classes': model.num_classes
    }
    json.dump(model_params, json_file)

def save_to_drive(title, content):
  # Install the PyDrive wrapper & import libraries.
  !pip install -U -q PyDrive
  from pydrive.auth import GoogleAuth
  from pydrive.drive import GoogleDrive
  from google.colab import auth
  from oauth2client.client import GoogleCredentials

  # Authenticate and create the PyDrive client.
  auth.authenticate_user()
  gauth = GoogleAuth()
  gauth.credentials = GoogleCredentials.get_application_default()
  drive = GoogleDrive(gauth)

  newfile = drive.CreateFile({'title': title})
  newfile.SetContentFile(content)
  newfile.Upload()
  print('Uploaded file with ID %s as %s'% (newfile.get('id'),
         title))
    
archive_name = ''.join([file_name,'_seedbank_char-rnn.zip'])
latest_checkpoint = tf.train.latest_checkpoint(CHECKPOINT_DIR)
checkpoints_archive_path = ''.join(['./exports/',archive_name])
# latest_checkpoint() returns None when no checkpoint exists, so check first.
if not latest_checkpoint:
  raise ValueError('You must train a model before you can export one.')
latest_model = latest_checkpoint.split('/')[2]
  
%system mkdir exports
%rm -f {checkpoints_archive_path}
mapper.save(''.join([CHECKPOINT_DIR, 'token_mapping.json']))
save_model_hyperparameters(''.join([CHECKPOINT_DIR, 'model_attributes.json']))
%system zip '{checkpoints_archive_path}' -@ '{CHECKPOINT_DIR}checkpoint' \
            '{CHECKPOINT_DIR}token_mapping.json' \
            '{CHECKPOINT_DIR}model_attributes.json' \
            '{CHECKPOINT_DIR}{latest_model}.'*

if save_model_to_drive:
  save_to_drive(archive_name, checkpoints_archive_path)
else:
  files.download(checkpoints_archive_path)



Uploaded file with ID 1dQEv67yQe10ccsW13sJx89ilhCDmXzxP as shakespeare_seedbank_char-rnn.zip
