@@ -16,12 +16,12 @@
logging .basicConfig (level = logging .INFO )
tf .app .flags .DEFINE_float ("learning_rate" , 0.01 , "Learning rate." )
tf .app .flags .DEFINE_float ("max_gradient_norm" , 10.0 , "Clip gradients to this norm." )
tf .app .flags .DEFINE_float ("max_gradient_norm" , 8 , "Clip gradients to this norm." )
tf .app .flags .DEFINE_float ("dropout" , 0.15 , "Fraction of units randomly dropped on non-recurrent connections." )
tf .app .flags .DEFINE_integer ("batch_size" , 100 , "Batch size to use during training." )
tf .app .flags .DEFINE_integer ("batch_size" , 50 , "Batch size to use during training." )
tf .app .flags .DEFINE_integer ("epochs" , 10 , "Number of epochs to train." )
tf .app .flags .DEFINE_integer ("state_size" , 50 , "Size of each model layer." )
tf .app .flags .DEFINE_integer ("output_size" , 750 , "The output size of your model." )
tf .app .flags .DEFINE_integer ("state_size" , 100 , "Size of each model layer." )
tf .app .flags .DEFINE_integer ("output_size" , 766 , "The output size of your model." )
tf .app .flags .DEFINE_integer ("embedding_size" , 100 , "Size of the pretrained vocabulary." )
tf .app .flags .DEFINE_string ("data_dir" , "data/squad" , "SQuAD directory (default ./data/squad)" )
tf .app .flags .DEFINE_string ("train_dir" , "train" , "Training directory to save the model parameters (default: ./train)." )
@@ -34,9 +34,9 @@
tf .app .flags .DEFINE_string ("embed_path" , "" , "Path to the trimmed GLoVe embedding (default: ./data/squad/glove.trimmed.{embedding_size}.npz)" )
# added
tf .app .flags .DEFINE_string ("model_type" , "lstm " , "specify either gru or lstm cell type for encoding" )
tf .app .flags .DEFINE_string ("model_type" , "gru " , "specify either gru or lstm cell type for encoding" )
tf .app .flags .DEFINE_integer ("debug" , 1 , "whether to set debug or not" )
tf . app . flags . DEFINE_integer ( "grad_clip" , 1 , "whether to clip gradients or not" )
FLAGS = tf .app .flags .FLAGS
@@ -152,7 +152,7 @@ def main(_):
question_encoder = Encoder (size = FLAGS .state_size , vocab_dim = FLAGS .embedding_size , name = "question_encoder" )
context_encoder = Encoder (size = FLAGS .state_size , vocab_dim = FLAGS .embedding_size , name = "context_encoder" )
decoder = Decoder (output_size = FLAGS .output_size )
decoder = Decoder (output_size = FLAGS .output_size , name = "decoder" )
qa = QASystem (encoder = (question_encoder ,context_encoder ),
decoder = decoder ,