# generate_sample



In [1]:
import os
import time
import pickle
import numpy as np
import tensorflow as tf

In [2]:
from utils.text_encoder import NanJThreadTitleEncoder
# Restore the title encoder from ../model/text_encoder. The rest of this notebook
# relies on it for vocab_size(), SOS_TOKEN_ID, encode() and decode().
text_encoder = NanJThreadTitleEncoder.load_from_file("../model/text_encoder")

In [3]:
# Load the pickled sequences and pad them into one rectangular array.
# NOTE(review): presumably `ids` is a list of token-ID lists (the encoded
# training titles) — confirm against the script that wrote input.pickle.
# pickle.load can execute arbitrary code; only load a file produced by this project.
with open("../model/input.pickle", "rb") as f:
    ids = pickle.load(f)
# padding='post' puts the padding after each sequence; no maxlen is given, so
# every row is padded to the length of the longest sequence.
input_tensor = tf.keras.preprocessing.sequence.pad_sequences(ids, padding='post')

In [4]:
# Model Parameters
# NOTE(review): these presumably have to match the values used when the
# checkpoints loaded below were trained — confirm against the training notebook.
vocab_size = text_encoder.vocab_size()
embedding_dim = 128
gen_units = 128  # passed to create_rnn_generator_model
gru_units = 128  # passed to create_generation_evaluator_model
num_stacks = 4  # passed to the generator; generate() builds one state tensor per stack
seq_len = input_tensor.shape[1]  # padded sequence length; also the length generate() fills up to

In [5]:
from utils.model import create_rnn_generator_model, create_generation_evaluator_model

In [6]:
# Build the generator and load the checkpoint saved as weights_epoch<N>.h5.
generator_epoch = 16
generator = create_rnn_generator_model(gen_units, embedding_dim, vocab_size, num_stacks)
generator.load_weights(f'../model/rnn_generator/weights_epoch{generator_epoch}.h5')
# This notebook only samples from the model, so mark it as non-trainable.
generator.trainable = False

# The generator's embedding layer is handed to the evaluator constructor below.
embedding_layer = generator.get_layer(name="embedding")
embedding_layer.trainable = False

# Build the evaluator and load its own checkpoint the same way.
evaluator_epoch = 10
evaluator = create_generation_evaluator_model(embedding_layer, gru_units, embedding_dim, seq_len, vocab_size)
evaluator.load_weights(f'../model/evaluator/weights_epoch{evaluator_epoch}.h5')
evaluator.trainable = False

In [7]:
# Print the generator's layer table (layer names, output shapes, parameter counts).
generator.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 1)]          0                                            
__________________________________________________________________________________________________
embedding (Embedding)           (None, 1, 128)       2560512     input_1[0][0]                    
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, 128)]        0                                            
__________________________________________________________________________________________________
gru (GRU)                       [(None, 1, 128), (No 99072       embedding[0][0]                  
                                                                 input_2[0][0]                

In [8]:
# Print the evaluator's layer table (layer names, output shapes, parameter counts).
evaluator.summary()

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_6 (InputLayer)         [(None, 24)]              0         
_________________________________________________________________
embedding_1 (Embedding)      (None, 24, 128)           2560512   
_________________________________________________________________
bidirectional (Bidirectional (None, 24, 256)           198144    
_________________________________________________________________
bidirectional_1 (Bidirection (None, 256)               296448    
_________________________________________________________________
dense_1 (Dense)              (None, 128)               32896     
_________________________________________________________________
batch_normalization (BatchNo (None, 128)               512       
_________________________________________________________________
activation (Activation)      (None, 128)               0   

In [9]:
def generate(num_generations, prefix=None):
    """
    Generate `num_generations` documents (token-ID sequences) with the generator.

    Every sequence starts with the SOS token. When `prefix` is given, its token
    IDs are placed directly after SOS in every sequence, so all generated thread
    titles share that prefix; only the remaining positions are sampled.

    Args:
        num_generations: Number of sequences to generate in one batch.
        prefix: Optional iterable of token IDs that every sequence must start
            with (for example the result of `text_encoder.encode([...])`).

    Returns:
        A tensor of token IDs with one row per generated sequence. Rows have
        `seq_len` columns, unless SOS plus the prefix already fills or exceeds
        `seq_len`; in that case nothing is sampled and each row holds just SOS
        followed by the prefix.
    """
    # One zero-initialised recurrent state per stacked layer. The width is
    # gen_units, the size the generator was built with; gru_units is the
    # evaluator's hyperparameter and only coincidentally has the same value.
    gen_states = [tf.zeros((num_generations, gen_units)) for _ in range(num_stacks)]

    sos_column = tf.expand_dims([text_encoder.SOS_TOKEN_ID] * num_generations, 1)
    generation_ids_list = [sos_column]

    if prefix is not None:
        for t in prefix:
            # Run the generator on the previous token only to advance its
            # recurrent state; the prediction is discarded because the next
            # token is forced to the prefix token.
            gen_output = generator([generation_ids_list[-1]] + gen_states)
            gen_states = gen_output[1:]
            generation_ids_list.append(tf.expand_dims([t] * num_generations, 1))

    # Sample the remaining positions one token at a time, feeding each sampled
    # token back in as the next input.
    for _ in range(seq_len - len(generation_ids_list)):
        gen_output = generator([generation_ids_list[-1]] + gen_states)
        predictions = gen_output[0]
        gen_states = gen_output[1:]
        # tf.random.categorical expects 2-D unnormalised log-probabilities.
        # NOTE(review): assumes the generator's first output is (batch, vocab)
        # logits — confirm in utils.model.
        next_ids = tf.random.categorical(predictions, num_samples=1, dtype="int32")
        generation_ids_list.append(next_ids)

    return tf.concat(generation_ids_list, axis=1)

# Evaluatorあり

In [None]:
# Sample 5000 titles, score each one with the evaluator, and print the 20
# highest-scoring titles.
predicted_ids = generate(5000).numpy()
scores = evaluator(predicted_ids).numpy().flatten()
ranked = sorted(zip(predicted_ids, scores), key=lambda pair: -pair[1])
# The loop targets are deliberately not called `ids`: at notebook scope a loop
# target stays bound after the loop and would overwrite the `ids` loaded from
# input.pickle.
for title_ids, score in ranked[:20]:
    print("Generation:", " ".join(text_encoder.decode(title_ids)), "score:", score)

In [None]:
# Sample 5000 titles that all start with the given prefix, score each one with
# the evaluator, and print the 20 highest-scoring titles.
prefix = text_encoder.encode(["三大"])
predicted_ids = generate(5000, prefix=prefix).numpy()
scores = evaluator(predicted_ids).numpy().flatten()
ranked = sorted(zip(predicted_ids, scores), key=lambda pair: -pair[1])
# The loop targets are deliberately not called `ids`: at notebook scope a loop
# target stays bound after the loop and would overwrite the `ids` loaded from
# input.pickle.
for title_ids, score in ranked[:20]:
    print("Generation:", " ".join(text_encoder.decode(title_ids)), "score:", score)

In [None]:
# Sample 5000 titles that all start with the given prefix, score each one with
# the evaluator, and print the 20 highest-scoring titles.
prefix = text_encoder.encode(["【", "なぞなぞ", "】"])
predicted_ids = generate(5000, prefix=prefix).numpy()
scores = evaluator(predicted_ids).numpy().flatten()
ranked = sorted(zip(predicted_ids, scores), key=lambda pair: -pair[1])
# The loop targets are deliberately not called `ids`: at notebook scope a loop
# target stays bound after the loop and would overwrite the `ids` loaded from
# input.pickle.
for title_ids, score in ranked[:20]:
    print("Generation:", " ".join(text_encoder.decode(title_ids)), "score:", score)

# Evaluatorなし

In [None]:
# Sample 20 titles and print them as-is, without evaluator scoring.
predicted_ids = generate(20).numpy()
# The loop target is deliberately not called `ids`: at notebook scope it stays
# bound after the loop and would overwrite the `ids` loaded from input.pickle.
for title_ids in predicted_ids:
    print("Generation:", " ".join(text_encoder.decode(title_ids)))

In [None]:
# Sample 20 titles that all start with the given prefix and print them as-is,
# without evaluator scoring.
prefix = text_encoder.encode(["三大"])
predicted_ids = generate(20, prefix=prefix).numpy()
# The loop target is deliberately not called `ids`: at notebook scope it stays
# bound after the loop and would overwrite the `ids` loaded from input.pickle.
for title_ids in predicted_ids:
    print("Generation:", " ".join(text_encoder.decode(title_ids)))

In [None]:
# Sample 20 titles that all start with the given prefix and print them as-is,
# without evaluator scoring.
prefix = text_encoder.encode(["【", "なぞなぞ", "】"])
predicted_ids = generate(20, prefix=prefix).numpy()
# The loop target is deliberately not called `ids`: at notebook scope it stays
# bound after the loop and would overwrite the `ids` loaded from input.pickle.
for title_ids in predicted_ids:
    print("Generation:", " ".join(text_encoder.decode(title_ids)))

In [19]:
# memo