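# Run model module locally

Runs the WGAN-GP `trainer.task` module on this machine. Training and evaluation TFRecords are read from, and the trained model is written to, the GCS bucket configured in the first cell.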

In [1]:
import os

# Export every hyperparameter as an environment variable so the %%bash cell
# below can forward it to `trainer.task` as a command-line flag.
# os.environ only accepts str values, so each value is passed through str()
# (e.g. 0.00005 becomes "5e-05", 5.0 becomes "5.0").
os.environ.update({
    name: str(value)
    for name, value in {
        # Input data locations and model output directory.
        "TRAIN_FILE_PATTERN": "gs://machine-learning-1234-bucket/gan/wgan_gp/data/train*.tfrecord",
        "EVAL_FILE_PATTERN": "gs://machine-learning-1234-bucket/gan/wgan_gp/data/test*.tfrecord",
        "OUTPUT_DIR": "gs://machine-learning-1234-bucket/gan/wgan_gp/trained_model",

        # Training loop.
        "TRAIN_BATCH_SIZE": 32,
        "TRAIN_STEPS": 200,

        # Evaluation loop and its scheduling.
        "EVAL_BATCH_SIZE": 32,
        "EVAL_STEPS": 100,
        "START_DELAY_SECS": 60,
        "THROTTLE_SECS": 120,

        # Image dimensions (height x width x channels).
        "HEIGHT": 32,
        "WIDTH": 32,
        "DEPTH": 3,

        # Generator. List-valued settings are comma-separated strings.
        # NOTE(review): 8 * 1 * 2 * 2 (projection size times the strides) = 32,
        # which presumably matches HEIGHT/WIDTH; confirm against the trainer.
        "LATENT_SIZE": 512,
        "GENERATOR_PROJECTION_DIMS": "8,8,256",
        "GENERATOR_NUM_FILTERS": "128,64",
        "GENERATOR_KERNEL_SIZES": "5,5",
        "GENERATOR_STRIDES": "1,2",
        "GENERATOR_FINAL_NUM_FILTERS": 3,
        "GENERATOR_FINAL_KERNEL_SIZE": 5,
        "GENERATOR_FINAL_STRIDE": 2,
        "GENERATOR_L1_REGULARIZATION_SCALE": 0.01,
        "GENERATOR_L2_REGULARIZATION_SCALE": 0.01,
        "GENERATOR_OPTIMIZER": "Adam",
        "GENERATOR_LEARNING_RATE": 0.00005,
        "GENERATOR_CLIP_GRADIENTS": 5.0,
        "GENERATOR_TRAIN_STEPS": 1,

        # Critic. List-valued settings are comma-separated strings.
        "CRITIC_NUM_FILTERS": "64,128",
        "CRITIC_KERNEL_SIZES": "5,5",
        "CRITIC_STRIDES": "2,2",
        "CRITIC_DROPOUT_RATES": "0.3,0.3",
        "CRITIC_L1_REGULARIZATION_SCALE": 0.01,
        "CRITIC_L2_REGULARIZATION_SCALE": 0.01,
        "CRITIC_OPTIMIZER": "RMSProp",
        "CRITIC_LEARNING_RATE": 0.00005,
        "CRITIC_CLIP_GRADIENTS": 5.0,
        "CRITIC_GRADIENT_PENALTY_COEFFICIENT": 10.0,
        "CRITIC_TRAIN_STEPS": 5,
    }.items()
})


## Train WGAN-GP model

In [2]:
%%bash
# Remove output from any previous run so training starts fresh instead of
# restoring old checkpoints from the model directory.
# OUTPUT_DIR is a gs:// path in this notebook, so a plain local `rm -rf`
# would never touch it: use gsutil for GCS paths and rm for local ones.
# ${OUTPUT_DIR:?} aborts the cell if the variable is unset or empty, rather
# than running a delete with no target.
if [[ "${OUTPUT_DIR:?}" == gs://* ]]; then
    # gsutil exits non-zero when nothing matches (first run); that is fine.
    gsutil -m rm -r "${OUTPUT_DIR}" 2>/dev/null || true
else
    rm -rf "${OUTPUT_DIR}"
fi

# Make the trainer package importable. ${PYTHONPATH:+...} avoids a leading
# empty entry (i.e. the current directory) when PYTHONPATH is unset.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}${PWD}/wgan_gp_module"

# Every expansion is quoted: the file patterns contain '*', which bash would
# otherwise try to glob-expand against the local filesystem.
python3 -m trainer.task \
    --train_file_pattern="${TRAIN_FILE_PATTERN}" \
    --eval_file_pattern="${EVAL_FILE_PATTERN}" \
    --output_dir="${OUTPUT_DIR}" \
    --job-dir=./tmp \
    \
    --train_batch_size="${TRAIN_BATCH_SIZE}" \
    --train_steps="${TRAIN_STEPS}" \
    \
    --eval_batch_size="${EVAL_BATCH_SIZE}" \
    --eval_steps="${EVAL_STEPS}" \
    --start_delay_secs="${START_DELAY_SECS}" \
    --throttle_secs="${THROTTLE_SECS}" \
    \
    --height="${HEIGHT}" \
    --width="${WIDTH}" \
    --depth="${DEPTH}" \
    \
    --latent_size="${LATENT_SIZE}" \
    --generator_projection_dims="${GENERATOR_PROJECTION_DIMS}" \
    --generator_num_filters="${GENERATOR_NUM_FILTERS}" \
    --generator_kernel_sizes="${GENERATOR_KERNEL_SIZES}" \
    --generator_strides="${GENERATOR_STRIDES}" \
    --generator_final_num_filters="${GENERATOR_FINAL_NUM_FILTERS}" \
    --generator_final_kernel_size="${GENERATOR_FINAL_KERNEL_SIZE}" \
    --generator_final_stride="${GENERATOR_FINAL_STRIDE}" \
    --generator_l1_regularization_scale="${GENERATOR_L1_REGULARIZATION_SCALE}" \
    --generator_l2_regularization_scale="${GENERATOR_L2_REGULARIZATION_SCALE}" \
    --generator_optimizer="${GENERATOR_OPTIMIZER}" \
    --generator_learning_rate="${GENERATOR_LEARNING_RATE}" \
    --generator_clip_gradients="${GENERATOR_CLIP_GRADIENTS}" \
    --generator_train_steps="${GENERATOR_TRAIN_STEPS}" \
    \
    --critic_num_filters="${CRITIC_NUM_FILTERS}" \
    --critic_kernel_sizes="${CRITIC_KERNEL_SIZES}" \
    --critic_strides="${CRITIC_STRIDES}" \
    --critic_dropout_rates="${CRITIC_DROPOUT_RATES}" \
    --critic_l1_regularization_scale="${CRITIC_L1_REGULARIZATION_SCALE}" \
    --critic_l2_regularization_scale="${CRITIC_L2_REGULARIZATION_SCALE}" \
    --critic_optimizer="${CRITIC_OPTIMIZER}" \
    --critic_learning_rate="${CRITIC_LEARNING_RATE}" \
    --critic_clip_gradients="${CRITIC_CLIP_GRADIENTS}" \
    --critic_gradient_penalty_coefficient="${CRITIC_GRADIENT_PENALTY_COEFFICIENT}" \
    --critic_train_steps="${CRITIC_TRAIN_STEPS}"


decode_example: features = {'image_raw': FixedLenFeature(shape=[], dtype=tf.string, default_value=None), 'label': FixedLenFeature(shape=[], dtype=tf.int64, default_value=None)}
decode_example: image = Tensor("DecodeRaw:0", shape=(?,), dtype=uint8)
decode_example: image = Tensor("Reshape:0", shape=(32, 32, 3), dtype=uint8)
decode_example: image = Tensor("sub:0", shape=(32, 32, 3), dtype=float32)
decode_example: label = Tensor("Cast_1:0", shape=(), dtype=int32)

wgan_model: features = {'image': <tf.Tensor 'IteratorGetNext:0' shape=(?, 32, 32, 3) dtype=float32>}
wgan_model: labels = Tensor("IteratorGetNext:1", shape=(?,), dtype=int32, device=/device:CPU:0)
wgan_model: mode = train
wgan_model: params = {'train_file_pattern': 'gs://machine-learning-1234-bucket/gan/wgan_gp/data/train*.tfrecord', 'eval_file_pattern': 'gs://machine-learning-1234-bucket/gan/wgan_gp/data/test*.tfrecord', 'output_dir': 'gs://machine-learning-1234-bucket/gan/wgan_gp/trained_model/', 'train_batch_size': 32, 'train



INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': 'gs://machine-learning-1234-bucket/gan/wgan_gp/trained_model/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f0fd17f7cd0>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tenso