# MaskGAN: Better Text Generation via Filling in the ______

Code for [*MaskGAN: Better Text Generation via Filling in the ______*](https://arxiv.org/abs/1801.07736), published at ICLR 2018.

## Requirements

* TensorFlow >= v1.5
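Not part of the original instructions, but a quick way to confirm that the installed TensorFlow meets this requirement:

```python
import tensorflow as tf

# This README asks for TensorFlow >= 1.5. The scripts predate TF 2.x, so a
# 1.x release (1.5 or newer) is the safe assumption here.
print(tf.__version__)
```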

## Instructions

**Warning:** The open-source version of this code is still being tested; pretraining may not work correctly.

For training on PTB (Penn Treebank):

1. Pretrain a language model (LM) on PTB and store the checkpoint in `/tmp/pretrain-lm/`. Instructions are a work in progress.

2. Run MaskGAN in MLE pretraining mode. If step 1 was not run, set `--language_model_ckpt_dir` to the empty string. (The `--hparams` string format is sketched after the command below.)


```bash
python train_mask_gan.py \
  --data_dir='/tmp/ptb' \
  --batch_size=20 \
  --sequence_length=20 \
  --base_directory='/tmp/maskGAN' \
  --hparams="gen_rnn_size=650,dis_rnn_size=650,gen_num_layers=2,dis_num_layers=2,gen_learning_rate=0.00074876,dis_learning_rate=5e-4,baseline_decay=0.99,dis_train_iterations=1,gen_learning_rate_decay=0.95" \
  --mode='TRAIN' \
  --max_steps=100000 \
  --language_model_ckpt_dir=/tmp/pretrain-lm/ \
  --generator_model='seq2seq_vd' \
  --discriminator_model='rnn_zaremba' \
  --is_present_rate=0.5 \
  --summaries_every=10 \
  --print_every=250 \
  --max_num_to_print=3 \
  --gen_training_strategy=cross_entropy \
  --seq2seq_share_embedding
```
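
The `--hparams` flag is a single comma-separated `name=value` string. A minimal sketch of how such a string is typically parsed in TF 1.x code, assuming the usual `tf.contrib.training.HParams` pattern (the real defaults and parsing live in `train_mask_gan.py` and may differ):

```python
import tensorflow as tf

# Placeholder defaults, for illustration only; train_mask_gan.py defines the real ones.
hparams = tf.contrib.training.HParams(
    gen_rnn_size=128,
    dis_rnn_size=128,
    gen_learning_rate=1e-3,
    dis_learning_rate=1e-3)

# The --hparams string overrides the defaults, keeping each value's declared type.
hparams.parse("gen_rnn_size=650,dis_rnn_size=650,"
              "gen_learning_rate=0.00074876,dis_learning_rate=5e-4")
print(hparams.gen_rnn_size)       # 650
print(hparams.gen_learning_rate)  # 0.00074876
```
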
3. Run MaskGAN in GAN mode. If step 2 was not run, set `--maskgan_ckpt` to the empty string. (A sketch of the REINFORCE objective used here follows the command.)

```bash
python train_mask_gan.py \
  --data_dir='/tmp/ptb' \
  --batch_size=128 \
  --sequence_length=20 \
  --base_directory='/tmp/maskGAN' \
  --mask_strategy=contiguous \
  --maskgan_ckpt='/tmp/maskGAN' \
  --hparams="gen_rnn_size=650,dis_rnn_size=650,gen_num_layers=2,dis_num_layers=2,gen_learning_rate=0.000038877,gen_learning_rate_decay=1.0,gen_full_learning_rate_steps=2000000,gen_vd_keep_prob=0.33971,rl_discount_rate=0.89072,dis_learning_rate=5e-4,baseline_decay=0.99,dis_train_iterations=2,dis_pretrain_learning_rate=0.005,critic_learning_rate=5.1761e-7,dis_vd_keep_prob=0.71940" \
  --mode='TRAIN' \
  --max_steps=100000 \
  --generator_model='seq2seq_vd' \
  --discriminator_model='seq2seq_vd' \
  --is_present_rate=0.5 \
  --summaries_every=250 \
  --print_every=250 \
  --max_num_to_print=3 \
  --gen_training_strategy='reinforce' \
  --seq2seq_share_embedding=true \
  --baseline_method=critic \
  --attention_option=luong
```
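
For intuition about `--gen_training_strategy='reinforce'` with `--baseline_method=critic`: in GAN mode the generator is updated with a policy gradient, where each filled-in token receives a reward derived from the discriminator and a learned critic provides a variance-reducing baseline. A rough numpy sketch of the advantage computation, assuming per-token log-D rewards and using `rl_discount_rate` as the discount (illustration only; the repo's version is built into its TensorFlow graph, presumably under `losses/`):

```python
import numpy as np

def reinforce_advantages(dis_probs, critic_baselines, discount=0.89072):
    """Per-token advantage = discounted reward-to-go minus the critic baseline.

    dis_probs:        discriminator's probability that each filled-in token is real.
    critic_baselines: critic's estimate of the discounted future reward per step.
    Sketch only; the actual losses are defined by the training code.
    """
    rewards = np.log(np.clip(dis_probs, 1e-7, 1.0))  # per-token reward: log D
    reward_to_go = np.zeros_like(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):          # discounted cumulative reward
        running = rewards[t] + discount * running
        reward_to_go[t] = running
    return reward_to_go - critic_baselines           # critic baseline reduces variance

print(reinforce_advantages(np.array([0.9, 0.4, 0.7]),
                           np.array([-0.5, -0.8, -0.4])))
```
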
4. Generate samples (a note on the masking flags follows the command):

```bash
python generate_samples.py \
  --data_dir /tmp/ptb/ \
  --data_set=ptb \
  --batch_size=256 \
  --sequence_length=20 \
  --base_directory /tmp/imdbsample/ \
  --hparams="gen_rnn_size=650,dis_rnn_size=650,gen_num_layers=2,gen_vd_keep_prob=0.33971" \
  --generator_model=seq2seq_vd \
  --discriminator_model=seq2seq_vd \
  --is_present_rate=0.0 \
  --maskgan_ckpt=/tmp/maskGAN \
  --seq2seq_share_embedding=True \
  --dis_share_embedding=True \
  --attention_option=luong \
  --mask_strategy=contiguous \
  --baseline_method=critic \
  --number_epochs=4
```
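
A note on two of the flags above: `--mask_strategy=contiguous` blanks out a single contiguous span per sequence, and `--is_present_rate` is the fraction of tokens left visible, so 0.5 during training hides half of each 20-token sequence, while 0.0 at sampling time makes the model generate every token. A rough numpy illustration of such a mask (the repo's actual masking code may differ in details):

```python
import numpy as np

def contiguous_present_mask(sequence_length, is_present_rate, rng=np.random):
    """Boolean mask: True = token is given to the model, False = token must be filled in.

    Illustration only; the real masking logic lives in the training code.
    """
    num_masked = int(round((1.0 - is_present_rate) * sequence_length))
    mask = np.ones(sequence_length, dtype=bool)
    if num_masked > 0:
        start = rng.randint(0, sequence_length - num_masked + 1)
        mask[start:start + num_masked] = False
    return mask

print(contiguous_present_mask(20, 0.5))  # one contiguous half of the sequence is blanked
print(contiguous_present_mask(20, 0.0))  # everything is blanked: free-running generation
```
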

## Contact for Issues
