# train_rnn_generator

RNNでスレタイ生成をするGeneratorを学習する。

In [1]:
import os
import time
import pickle
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split

In [2]:
from utils.text_encoder import NanJThreadTitleEncoder
# Load the text encoder from ../model/text_encoder (presumably written by an
# earlier preprocessing step that is not shown here). The cells below use its
# vocab_size(), SOS_TOKEN_ID and decode().
text_encoder = NanJThreadTitleEncoder.load_from_file("../model/text_encoder")

In [3]:
# Read the pickled sequences (presumably token IDs per thread title) and
# right-pad them into a single rectangular array: padding='post' appends the
# pad value after the end of each sequence.
# NOTE(review): pickle.load can execute arbitrary code from the file; only
# point this at artifacts produced by this project.
with open("../model/input.pickle", mode="rb") as pickle_file:
    ids = pickle.load(pickle_file)
input_tensor = tf.keras.preprocessing.sequence.pad_sequences(
    sequences=ids,
    padding='post',
)

In [4]:
# Hold out 10% of the padded sequences for validation.
# random_state pins the shuffle so the same rows land in each split on every
# run; with an unseeded split, a generator restored from a saved checkpoint
# could be validated on rows it was trained on in an earlier session.
input_tensor_train, input_tensor_valid = train_test_split(
    input_tensor,
    test_size=0.1,
    random_state=0,
)
print(len(input_tensor_train), len(input_tensor_valid))

1452049 161339


In [5]:
# Wrap each split in a tf.data pipeline that yields one sequence per element.
# The shuffle buffer is set to the size of the split it is applied to, so the
# whole split fits in the buffer.
BUFFER_SIZE = len(input_tensor_train)
dataset_train = tf.data.Dataset.from_tensor_slices(input_tensor_train)
dataset_train = dataset_train.shuffle(BUFFER_SIZE)

# BUFFER_SIZE is deliberately re-bound: after this cell it holds the
# validation split size.
BUFFER_SIZE = len(input_tensor_valid)
dataset_valid = tf.data.Dataset.from_tensor_slices(input_tensor_valid)
dataset_valid = dataset_valid.shuffle(BUFFER_SIZE)

In [6]:
# Model Parameters
vocab_size = text_encoder.vocab_size()  # vocabulary size reported by the text encoder
embedding_dim = 128  # passed to create_rnn_generator_model as the embedding dimension
gen_units = 128  # width of each recurrent state tensor fed to the generator
num_stacks = 4  # number of state tensors (one per stack) the generator takes
seq_len = input_tensor.shape[1]  # padded sequence length (axis 1 of the padded array)

In [7]:
from utils.model import create_rnn_generator_model

In [8]:
# Build the generator. As it is used in the cells below, it is called with
# [token_ids] + one state tensor per stack and returns a sequence whose first
# element is the next-token predictions and whose remaining elements are the
# updated states.
generator = create_rnn_generator_model(gen_units, embedding_dim, vocab_size, num_stacks)

In [9]:
# Print the layer-by-layer architecture to sanity-check output shapes and
# parameter counts before training.
generator.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 1)]          0                                            
__________________________________________________________________________________________________
embedding (Embedding)           (None, 1, 128)       2560512     input_1[0][0]                    
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, 128)]        0                                            
__________________________________________________________________________________________________
gru (GRU)                       [(None, 1, 128), (No 99072       embedding[0][0]                  
                                                                 input_2[0][0]                

In [10]:
def generate(num_generations):
    """Sample ``num_generations`` documents (token-ID sequences) from the generator.

    Every sequence starts with ``text_encoder.SOS_TOKEN_ID`` and is extended one
    token at a time by sampling from the generator's output until it holds
    ``seq_len`` tokens. Sampling never stops early, so every row has exactly
    ``seq_len`` entries.

    Args:
        num_generations: Number of sequences to sample in parallel.

    Returns:
        An integer tensor of shape ``(num_generations, seq_len)``.
    """
    # One all-zero initial state per stack. The width is gen_units, the same
    # value the generator was built with (the original referenced an undefined
    # name here and raised NameError).
    gen_states = [
        tf.zeros((num_generations, gen_units)) for _ in range(num_stacks)
    ]

    # Column vector of start-of-sequence IDs: shape (num_generations, 1).
    gen_input = tf.expand_dims([text_encoder.SOS_TOKEN_ID] * num_generations, 1)
    generation_ids_list = [gen_input]
    for _ in range(seq_len - 1):
        # Feed the most recently produced token together with the current states.
        gen_output = generator([generation_ids_list[-1]] + gen_states)
        predictions = gen_output[0]
        # Everything after the predictions is the updated state for the next step.
        gen_states = gen_output[1:]
        # Draw one next-token ID per row from the predicted distribution.
        next_ids = tf.random.categorical(predictions, num_samples=1, dtype="int32")
        generation_ids_list.append(next_ids)
    return tf.concat(generation_ids_list, axis=1)

In [None]:
# Initial state: sample from the generator before any training step has run.
generation_ids = generate(10).numpy()
for sample_ids in generation_ids:
    tokens = text_encoder.decode(sample_ids)
    print("Generation:", " ".join(tokens))

In [12]:
# Adam with the Keras default hyper-parameters.
optimizer = tf.keras.optimizers.Adam()

# Targets are integer token IDs and the generator output is treated as raw
# (un-normalised) logits, hence the sparse variant with from_logits=True.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

In [13]:
@tf.function
def train_step(inp):    # inp: (batch, seq_len)
    """Run one teacher-forced optimisation step on a batch of token-ID sequences.

    At each position ``i`` the generator receives the ground-truth token
    ``inp[:, i]`` plus the current states and is scored on predicting
    ``inp[:, i+1]``. The per-position losses are summed, gradients of that sum
    are applied to ``generator.trainable_variables`` through ``optimizer``.

    Args:
        inp: Integer tensor of shape ``(batch, seq_len)``.

    Returns:
        The summed loss divided by the batch size (scalar tensor).
    """
    # Take the batch dimension from the input rather than the BATCH_SIZE
    # global, so the step also works for batches of a different size (e.g. a
    # final partial batch) instead of failing on a state-shape mismatch.
    batch_size = tf.shape(inp)[0]

    # One all-zero initial state per stack.
    gen_states = [tf.zeros((batch_size, gen_units)) for _ in range(num_stacks)]

    loss = 0
    with tf.GradientTape() as tape:
        for i in range(inp.shape[1] - 1):
            gen_input = tf.expand_dims(inp[:, i], 1)
            gen_output = generator([gen_input] + gen_states)
            predictions = gen_output[0]
            # Remaining outputs are the updated states, carried to the next position.
            gen_states = gen_output[1:]
            loss += loss_object(inp[:, i+1], predictions)

    # NOTE(review): with the default Keras loss reduction each term above is
    # presumably already a batch mean, so dividing by the batch size again
    # looks unintended (a per-position average may have been meant). Kept
    # as-is so reported values stay comparable; this affects logging only,
    # the gradients below use the un-normalised sum.
    batch_loss = loss / tf.cast(batch_size, loss.dtype)

    gradients = tape.gradient(loss, generator.trainable_variables)
    optimizer.apply_gradients(zip(gradients, generator.trainable_variables))

    return batch_loss

In [14]:
@tf.function
def valid_step(inp):
    """Compute the teacher-forced loss on a batch without updating the generator.

    At each position ``i`` the generator receives the ground-truth token
    ``inp[:, i]`` plus the current states and is scored on predicting
    ``inp[:, i+1]``; the per-position losses are summed.

    Args:
        inp: Integer tensor of shape ``(batch, seq_len)``.

    Returns:
        The summed loss divided by the batch size (scalar tensor).
    """
    # Take the batch dimension from the input rather than the BATCH_SIZE
    # global, so the step also works for batches of a different size (e.g. a
    # final partial batch) instead of failing on a state-shape mismatch.
    batch_size = tf.shape(inp)[0]

    # One all-zero initial state per stack.
    gen_states = [tf.zeros((batch_size, gen_units)) for _ in range(num_stacks)]

    loss = 0
    for i in range(inp.shape[1] - 1):
        gen_input = tf.expand_dims(inp[:, i], 1)
        gen_output = generator([gen_input] + gen_states)
        predictions = gen_output[0]
        # Remaining outputs are the updated states, carried to the next position.
        gen_states = gen_output[1:]
        loss += loss_object(inp[:, i+1], predictions)

    # NOTE(review): with the default Keras loss reduction each term above is
    # presumably already a batch mean, so dividing by the batch size again
    # looks unintended (a per-position average may have been meant). Kept
    # as-is so reported values stay comparable.
    batch_loss = loss / tf.cast(batch_size, loss.dtype)

    return batch_loss

In [15]:
EPOCHS = 20  # number of passes of the training loop below
BATCH_SIZE = 128  # rows per batch for both the training and validation datasets
# Whole batches per epoch: floor division, so a trailing partial batch is not counted.
steps_per_epoch_train = len(input_tensor_train)//BATCH_SIZE
steps_per_epoch_valid = len(input_tensor_valid)//BATCH_SIZE

In [16]:
# Group rows into batches of BATCH_SIZE. drop_remainder=True discards the final
# short batch, so every batch has exactly BATCH_SIZE rows; this matches the
# floor division used for steps_per_epoch_train / steps_per_epoch_valid.
dataset_train = dataset_train.batch(BATCH_SIZE, drop_remainder=True)
dataset_valid = dataset_valid.batch(BATCH_SIZE, drop_remainder=True)

In [17]:
# Ensure the checkpoint directory exists. exist_ok=True makes this a no-op when
# the directory is already there and avoids the check-then-create race of
# os.path.exists + os.mkdir; missing parent directories are created as well.
os.makedirs("../model/rnn_generator", exist_ok=True)

In [None]:
# Main loop: for every epoch, train on all full batches, checkpoint the
# weights, measure the validation loss, and print a few generated samples.
for epoch in range(EPOCHS):
    start = time.time()

    # TRAIN
    total_train_loss = 0
    train_batches = dataset_train.take(steps_per_epoch_train)
    for step, batch_inputs in enumerate(train_batches):
        batch_loss = train_step(batch_inputs)
        total_train_loss += batch_loss

        # Progress line every 500 batches (batch 0 included).
        if step % 500 != 0:
            continue
        print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,
                                                     step,
                                                     batch_loss.numpy()))

    # One weights file per epoch, numbered from 1.
    generator.save_weights(f"../model/rnn_generator/weights_epoch{epoch+1}.h5")

    mean_train_loss = total_train_loss / steps_per_epoch_train
    print('Train Epoch {} Loss {:.4f}'.format(epoch + 1, mean_train_loss))

    # VALIDATION
    total_valid_loss = 0
    for batch_inputs in dataset_valid.take(steps_per_epoch_valid):
        total_valid_loss += valid_step(batch_inputs)

    mean_valid_loss = total_valid_loss / steps_per_epoch_valid
    print('Validation Loss {:.4f}'.format(mean_valid_loss))

    # GENERATION
    generation_ids = generate(10).numpy()
    for sample_ids in generation_ids:
        tokens = text_encoder.decode(sample_ids)
        print("Generation:", " ".join(tokens))

    elapsed = time.time() - start
    print('Time taken for 1 epoch {} sec\n'.format(elapsed))