# train generator

スレタイ生成を行うGeneratorを学習する。

GeneratorはTransformerのエンコーダーのMulti-Head AttentionをMasked Multi-Head Attentionに変更したネットワークで構成される。
詳細は[scripts/model.py](scripts/model.py)を参照


|パラメータ名|内容|
|:--|:--|
| `num_layers` | transformerのパラメータ、[tensorflowのチュートリアル参照](https://www.tensorflow.org/tutorials/text/transformer) |
| `d_model` | transformerのパラメータ、[tensorflowのチュートリアル参照](https://www.tensorflow.org/tutorials/text/transformer) |
| `dff` | transformerのパラメータ、[tensorflowのチュートリアル参照](https://www.tensorflow.org/tutorials/text/transformer) |
| `num_heads` | transformerのパラメータ、[tensorflowのチュートリアル参照](https://www.tensorflow.org/tutorials/text/transformer) |
| `TEMPERATURE` | 生成するスレッドタイトルの多様性と尤度のトレードオフパラメータ（0.0以上）、0.0のとき最も確信度の高いスレッドのみを生成する ※学習には影響しない|
| `EPOCHS` | 学習のエポック数 |
| `BATCH_SIZE` | 学習のバッチサイズ |

## 入力ファイル

* `real_dataset.pickle`
* `sentencepiece.model`

## 出力ファイル

* `model/generator/weights_epoch*.h5`

In [1]:
import os
import time
import pickle
import numpy as np
import tensorflow as tf

In [2]:
# Model hyperparameters (see the TensorFlow Transformer tutorial for their meaning)
num_layers, d_model, dff, num_heads = 4, 128, 512, 8

# Sampling temperature used when generating titles; the training step does not read it
TEMPERATURE = 0.85

# Training schedule
EPOCHS, BATCH_SIZE = 40, 128

In [3]:
import sentencepiece as spm
# Load the SentencePiece tokenizer model from the working directory.
sp = spm.SentencePieceProcessor()
# Left as the last expression so the notebook displays the load result.
sp.load('sentencepiece.model')

True

In [4]:
# Read the pickled sequences and right-pad them with zeros to the longest length.
# NOTE(review): pickle.load can execute arbitrary code; only load a trusted file.
with open("real_dataset.pickle", "rb") as dataset_file:
    ids = pickle.load(dataset_file)
input_tensor = tf.keras.preprocessing.sequence.pad_sequences(
    ids,
    padding='post',
)

In [5]:
# For debugging: shrink the training data
# input_tensor = input_tensor[:10000]

In [6]:
# Sizes derived from the padded data and the tokenizer
seq_len = input_tensor.shape[1]  # length of the second axis of the padded array
vocab_size = sp.get_piece_size()  # number of pieces in the SentencePiece model

In [7]:
from scripts.model import Generator

In [8]:
# Build the generator; the padded sequence length is passed as max_pos_encoding.
generator = Generator(num_layers, d_model, num_heads, dff, vocab_size, max_pos_encoding=seq_len)

In [None]:
# Sample a few titles from the freshly constructed model as a baseline
generation_ids = generator.sample(num_sample=10, temperature=TEMPERATURE, padding=True)

for ids in generation_ids:
    # decode_ids is given plain Python ints
    ids_int = [int(token_id) for token_id in ids]
    print(sp.decode_ids(ids_int))

In [10]:
from sklearn.model_selection import train_test_split
# Hold out 10% of the sequences for validation (no random seed is fixed here).
input_tensor_train, input_tensor_valid = train_test_split(
    input_tensor,
    test_size=0.1,
)
print(len(input_tensor_train), len(input_tensor_valid))

1536857 170762


In [11]:
# Number of full batches per epoch; a trailing partial batch is not counted.
steps_per_epoch_valid = len(input_tensor_valid) // BATCH_SIZE
steps_per_epoch_train = len(input_tensor_train) // BATCH_SIZE

In [12]:
# Wrap each split in a tf.data pipeline, shuffled with a buffer as large as the split.
BUFFER_SIZE = len(input_tensor_train)
train_slices = tf.data.Dataset.from_tensor_slices(input_tensor_train)
dataset_train = train_slices.shuffle(BUFFER_SIZE)

BUFFER_SIZE = len(input_tensor_valid)  # rebound for the validation split
valid_slices = tf.data.Dataset.from_tensor_slices(input_tensor_valid)
dataset_valid = valid_slices.shuffle(BUFFER_SIZE)

In [13]:
# Adam with its default settings
optimizer = tf.keras.optimizers.Adam()

# Cross entropy on raw logits with no reduction, so the caller receives
# per-element losses and decides how to aggregate them.
sparse_categorical_cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True,
    reduction='none',
)


def generator_loss(real, pred, mask):
    """Return the masked cross-entropy loss as a scalar.

    Args:
        real: Target token ids.
        pred: Logits predicted for each target position.
        mask: Tensor cast to the loss dtype; positions where it is
            false/zero contribute zero loss.

    Returns:
        The mean of the masked per-position losses. The mean runs over every
        position, masked ones included.
    """
    per_position = sparse_categorical_cross_entropy(real, pred)
    weights = tf.cast(mask, dtype=per_position.dtype)
    return tf.reduce_mean(per_position * weights)

In [14]:
@tf.function
def train_step(inp):    # inp: (BATCH_SIZE, seq_len)
    """Run one optimization step of next-token prediction and return the batch loss."""
    # Shift by one: the model reads every token but the last and predicts
    # every token but the first.
    model_input, expected = inp[:, :-1], inp[:, 1:]

    # Id 0 is excluded from the loss (it is the value used for padding).
    not_padding = tf.math.not_equal(expected, 0)

    with tf.GradientTape() as tape:
        logits = generator(model_input, training=True)
        gen_loss = generator_loss(expected, logits, not_padding)

    variables = generator.trainable_variables
    gradients = tape.gradient(gen_loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))

    return gen_loss

In [15]:
@tf.function
def valid_step(inp):    # inp: (BATCH_SIZE, seq_len)
    """Compute the next-token loss for one validation batch.

    No gradients are taken and no weights are updated.

    Args:
        inp: Batch of padded token-id sequences.

    Returns:
        Scalar loss for the batch, computed by ``generator_loss``.
    """
    # Shift by one so each position is trained to predict the following token.
    target_input = inp[:, :-1]
    target_real = inp[:, 1:]

    # Run in inference mode: validation must not use training-only behaviour
    # such as dropout (assumes Generator follows the Keras `training` flag
    # convention).
    gen_output = generator(target_input, training=False)

    # Id 0 is padding and is excluded from the loss.
    mask = tf.math.logical_not(tf.math.equal(target_real, 0))

    gen_loss = generator_loss(target_real, gen_output, mask)

    return gen_loss

In [16]:
# Group examples into fixed-size batches; the incomplete final batch is dropped.
dataset_train, dataset_valid = (
    split.batch(BATCH_SIZE, drop_remainder=True)
    for split in (dataset_train, dataset_valid)
)

In [17]:
# Directory that receives the per-epoch weight files.
model_dir = "model/generator"
# makedirs also creates the missing "model" parent and tolerates an existing
# directory, so the cell can be re-run and works on a fresh checkout.
os.makedirs(model_dir, exist_ok=True)

In [None]:
for epoch in range(EPOCHS):
    start = time.time()

    # --- training pass ---
    total_loss = 0
    for batch, inp in enumerate(dataset_train.take(steps_per_epoch_train)):
        batch_loss = train_step(inp)
        total_loss += batch_loss

        # Progress line every 500 batches, starting with batch 0
        if batch % 500 == 0:
            print(f'Epoch {epoch+1} Batch {batch} Loss {batch_loss.numpy():.4f}')

    # Save the weights once per epoch, before validation runs
    generator.save_weights(f"{model_dir}/weights_epoch{epoch+1}.h5")

    print(f'Train Epoch {epoch+1} Gen Loss {total_loss/steps_per_epoch_train:.4f}')

    # --- validation pass ---
    total_valid_loss = 0
    for inp in dataset_valid.take(steps_per_epoch_valid):
        total_valid_loss += valid_step(inp)

    print(f'Validation Loss {total_valid_loss/steps_per_epoch_valid:.4f}')

    # --- sample a few titles to inspect progress ---
    generation_ids = generator.sample(num_sample=10, temperature=TEMPERATURE, padding=True)
    for ids in generation_ids:
        print(sp.decode_ids([int(token_id) for token_id in ids]))

    print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))