Skip to content

Commit

Permalink
faster help
Browse files Browse the repository at this point in the history
  • Loading branch information
sherjilozair committed May 11, 2018
1 parent c582451 commit 47144fa
Showing 1 changed file with 52 additions and 55 deletions.
107 changes: 52 additions & 55 deletions train.py
@@ -1,68 +1,65 @@
#!/usr/bin/env python

from __future__ import print_function
import tensorflow as tf

import argparse
import time
import os
from six.moves import cPickle

from utils import TextLoader
from model import Model

# Command-line interface.
# The parser is built and evaluated at module level, so running (or importing)
# this file parses sys.argv immediately and leaves the result in the
# module-level name `args`.
# ArgumentDefaultsHelpFormatter makes --help display each option's default.
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# Data and model checkpoints directories
parser.add_argument('--data_dir', type=str, default='data/tinyshakespeare',
help='data directory containing input.txt with training examples')
parser.add_argument('--save_dir', type=str, default='save',
help='directory to store checkpointed models')
parser.add_argument('--log_dir', type=str, default='logs',
help='directory to store tensorboard logs')
parser.add_argument('--save_every', type=int, default=1000,
help='Save frequency. Number of passes between checkpoints of the model.')
# Resume path; stays None when the flag is not given.
parser.add_argument('--init_from', type=str, default=None,
help="""continue training from saved model at this path (usually "save").
Path must contain files saved by previous training process:
'config.pkl' : configuration;
'chars_vocab.pkl' : vocabulary definitions;
'checkpoint' : paths to model file(s) (created by tf).
Note: this file contains absolute paths, be careful when moving files around;
'model.ckpt-*' : file(s) with model definition (created by tf)
Model params must be the same between multiple runs (model, rnn_size, num_layers and seq_length).
""")
# Model params
parser.add_argument('--model', type=str, default='lstm',
help='lstm, rnn, gru, or nas')
parser.add_argument('--rnn_size', type=int, default=128,
help='size of RNN hidden state')
parser.add_argument('--num_layers', type=int, default=2,
help='number of layers in the RNN')
# Optimization
parser.add_argument('--seq_length', type=int, default=50,
help='RNN sequence length. Number of timesteps to unroll for.')
parser.add_argument('--batch_size', type=int, default=50,
help="""minibatch size. Number of sequences propagated through the network in parallel.
Pick batch-sizes to fully leverage the GPU (e.g. until the memory is filled up)
commonly in the range 10-500.""")
parser.add_argument('--num_epochs', type=int, default=50,
help='number of epochs. Number of full passes through the training examples.')
parser.add_argument('--grad_clip', type=float, default=5.,
help='clip gradients at this value')
parser.add_argument('--learning_rate', type=float, default=0.002,
help='learning rate')
parser.add_argument('--decay_rate', type=float, default=0.97,
help='decay rate for rmsprop')
# Keep probabilities default to 1.0.
parser.add_argument('--output_keep_prob', type=float, default=1.0,
help='probability of keeping weights in the hidden layer')
parser.add_argument('--input_keep_prob', type=float, default=1.0,
help='probability of keeping weights in the input layer')
# Parse sys.argv now; argparse exits the process on --help or invalid options.
args = parser.parse_args()

def main():
    """Parse the command-line options and start training.

    Registers the data/checkpoint, model and optimization options on an
    ``argparse`` parser (defaults are shown in ``--help``), parses
    ``sys.argv`` and hands the resulting namespace to ``train``.
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # One row per option: (flag, type, default, help text).
    # Rows are registered in this order, which is also the --help order.
    option_table = (
        # Data and model checkpoints directories
        ('--data_dir', str, 'data/tinyshakespeare',
         'data directory containing input.txt with training examples'),
        ('--save_dir', str, 'save',
         'directory to store checkpointed models'),
        ('--log_dir', str, 'logs',
         'directory to store tensorboard logs'),
        ('--save_every', int, 1000,
         'Save frequency. Number of passes between checkpoints of the model.'),
        ('--init_from', str, None,
         'continue training from saved model at this path (usually "save").\n'
         'Path must contain files saved by previous training process:\n'
         "'config.pkl' : configuration;\n"
         "'chars_vocab.pkl' : vocabulary definitions;\n"
         "'checkpoint' : paths to model file(s) (created by tf).\n"
         'Note: this file contains absolute paths, be careful when moving files around;\n'
         "'model.ckpt-*' : file(s) with model definition (created by tf)\n"
         'Model params must be the same between multiple runs (model, rnn_size, num_layers and seq_length).\n'),
        # Model params
        ('--model', str, 'lstm',
         'lstm, rnn, gru, or nas'),
        ('--rnn_size', int, 128,
         'size of RNN hidden state'),
        ('--num_layers', int, 2,
         'number of layers in the RNN'),
        # Optimization
        ('--seq_length', int, 50,
         'RNN sequence length. Number of timesteps to unroll for.'),
        ('--batch_size', int, 50,
         'minibatch size. Number of sequences propagated through the network in parallel.\n'
         'Pick batch-sizes to fully leverage the GPU (e.g. until the memory is filled up)\n'
         'commonly in the range 10-500.'),
        ('--num_epochs', int, 50,
         'number of epochs. Number of full passes through the training examples.'),
        ('--grad_clip', float, 5.,
         'clip gradients at this value'),
        ('--learning_rate', float, 0.002,
         'learning rate'),
        ('--decay_rate', float, 0.97,
         'decay rate for rmsprop'),
        ('--output_keep_prob', float, 1.0,
         'probability of keeping weights in the hidden layer'),
        ('--input_keep_prob', float, 1.0,
         'probability of keeping weights in the input layer'),
    )
    for flag, value_type, fallback, description in option_table:
        cli.add_argument(flag, type=value_type, default=fallback,
                         help=description)

    # Parse sys.argv and run training with the resulting namespace.
    train(cli.parse_args())

import tensorflow as tf
from utils import TextLoader
from model import Model

def train(args):
data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
Expand Down Expand Up @@ -144,4 +141,4 @@ def train(args):


# Script entry point.
# NOTE(review): two calls are listed under this guard -- `main()` and
# `train(args)`. This looks like the removed and the added line of a diff
# shown without +/- markers; confirm which single entry point is intended
# before running, since executing both would start training twice.
if __name__ == '__main__':
main()
train(args)

0 comments on commit 47144fa

Please sign in to comment.