
Commit

Beamsearch available
yanghoonkim committed Mar 27, 2019
1 parent bd3b929 commit 554fe0a
Showing 5 changed files with 59 additions and 18 deletions.
3 changes: 1 addition & 2 deletions README.md
@@ -13,7 +13,6 @@ Implementation of <Learning to Ask: Neural Question Generation for Reading Comprehension>
 - GRU/LSTM

 - To be updated
-    - Beam decoder
     - Post-processing code for unknown words

 2. **Dataset**
@@ -51,7 +50,7 @@ python process_embedding.py # This will take a couple of minutes
 # epochs: training epochs
 bash run.sh train [data_name] [hyperparameters] [epochs]
-# example : bash run.sh trian squad basic_params 10
+# example : bash run.sh train squad basic_params 10
 ```

 4. Test model
55 changes: 44 additions & 11 deletions enc_and_dec.py
@@ -92,6 +92,7 @@ def run(self, embd_input, sequence_length):
 class Decoder(_BaseClass):
     def __init__(self, enc_type = 'bi',
                  attn_type = 'bahdanau', voca_size = None,
+                 beam_width = 0, length_penalty_weight = 1,
                  num_layer = 1, hidden_size = 512,
                  cell_type = 'lstm', dropout = 0.1,
                  dtype = tf.float32, mode = tf.estimator.ModeKeys.TRAIN,
@@ -107,6 +108,8 @@ def __init__(self, enc_type ='bi',
                 )
         self.enc_type = enc_type
         self.attn_type = attn_type
+        self.beam_width = beam_width
+        self.length_penalty_weight = length_penalty_weight
         self.voca_size = voca_size
         self.sample_prob = sample_prob

@@ -129,26 +132,56 @@ def run(self, embd_input, sequence_length, embedding, start_token = 1, end_token
                 embedding, start_token, end_token
                 )

-        # Start decoding
-        initial_state = self.out_dec_cell.zero_state(dtype = self.dtype, batch_size = self.batch_size)
-        decoder = tf.contrib.seq2seq.BasicDecoder(
-                self.out_dec_cell, helper, initial_state,
-                output_layer = None)
+        # Decoder initial state setting
+        if (self.mode != tf.estimator.ModeKeys.PREDICT or self.beam_width == 0):
+            initial_state = self.out_dec_cell.zero_state(dtype = self.dtype, batch_size = self.batch_size)
+
+            decoder = tf.contrib.seq2seq.BasicDecoder(
+                    self.out_dec_cell, helper, initial_state,
+                    output_layer = None)
+        else:
+            initial_state = self.out_dec_cell.zero_state(dtype = self.dtype, batch_size = self.batch_size * self.beam_width)
+            decoder = tf.contrib.seq2seq.BeamSearchDecoder(
+                    cell = self.out_dec_cell,
+                    embedding = embedding,
+                    start_tokens = start_token,
+                    end_token = end_token,
+                    initial_state = initial_state,
+                    beam_width = self.beam_width,
+                    length_penalty_weight = self.length_penalty_weight)


         if self.mode == tf.estimator.ModeKeys.TRAIN:
             outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder, impute_finished = True, maximum_iterations = None)
-        else: # Test & Eval
-            outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder, impute_finished = True, maximum_iterations = None)
+            return outputs.rnn_output
+
+        # Test with Beam decoding
+        elif (self.mode == tf.estimator.ModeKeys.PREDICT and self.beam_width > 0):
+            outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder, impute_finished = False, maximum_iterations = self.max_iter)
+            predictions = outputs.predicted_ids # [batch, length, beam_width]
+            predictions = tf.transpose(predictions, [0, 2, 1]) # [batch, beam_width, length]
+            predictions = predictions[:, 0, :] # [batch, length]
+            return predictions
+
+        else: # Greedy decoder (Test & Eval)
+            outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder, impute_finished = True, maximum_iterations = self.max_iter)
+            return outputs.rnn_output

-        self.logits = outputs.rnn_output
-        return self.logits


-    def set_attentional_cell(self, memory, memory_length, encoder_state, enc_num_layer):
+    def set_attention_cell(self, memory, memory_length, encoder_state, enc_num_layer):
         self.batch_size = tf.shape(memory)[0]

         dec_cell = self._create_cell()

+        if (self.mode == tf.estimator.ModeKeys.PREDICT and self.beam_width > 0):
+            memory = tf.contrib.seq2seq.tile_batch(memory, self.beam_width)
+            memory_length = tf.contrib.seq2seq.tile_batch(memory_length, self.beam_width)
+
         attention_mechanism = self._attention(memory, memory_length)

         initial_cell_state = encoder_state if self.num_layer == enc_num_layer else None
@@ -189,4 +222,4 @@ def _attention(self, memory, memory_length):
                     memory_length,
                     scale = True)
         else:
-            raise ValueError('Unknown attention mechanism : %s' % attn_type)
+            raise ValueError('Unknown attention mechanism : %s' % self.attn_type)
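For reference, here is a minimal self-contained sketch of the TF 1.x `tf.contrib.seq2seq` beam search pattern this commit follows. All sizes and the GO/EOS ids are illustrative, not the repo's; note that `BeamSearchDecoder` expects `start_tokens` as a `[batch_size]` vector, and `dynamic_decode` must be called with `impute_finished=False` when beam decoding:

```python
import tensorflow as tf

seq2seq = tf.contrib.seq2seq

beam_width, vocab_size, hidden = 5, 30000, 512
embedding = tf.get_variable('embedding', [vocab_size, 300])
memory = tf.placeholder(tf.float32, [None, None, hidden])  # encoder outputs
memory_length = tf.placeholder(tf.int32, [None])           # source lengths
batch_size = tf.shape(memory)[0]

# 1) Tile every encoder-side tensor beam_width times so each hypothesis
#    attends over its own copy of the source.
tiled_memory = seq2seq.tile_batch(memory, multiplier=beam_width)
tiled_length = seq2seq.tile_batch(memory_length, multiplier=beam_width)

attention = seq2seq.BahdanauAttention(hidden, tiled_memory,
                                      memory_sequence_length=tiled_length)
cell = seq2seq.AttentionWrapper(tf.contrib.rnn.LSTMCell(hidden), attention)
out_cell = tf.contrib.rnn.OutputProjectionWrapper(cell, vocab_size)

# 2) The decoder runs batch_size * beam_width hypotheses in parallel.
initial_state = out_cell.zero_state(batch_size * beam_width, tf.float32)

decoder = seq2seq.BeamSearchDecoder(
    cell=out_cell,
    embedding=embedding,
    start_tokens=tf.fill([batch_size], 1),  # GO id, one per example
    end_token=2,                            # EOS id
    initial_state=initial_state,
    beam_width=beam_width,
    length_penalty_weight=1.0)

# 3) impute_finished must stay False for beam search in TF 1.x.
outputs, _, _ = seq2seq.dynamic_decode(decoder, impute_finished=False,
                                       maximum_iterations=30)

# predicted_ids: [batch, time, beam_width]; beams are sorted by score,
# so slicing beam 0 after the transpose keeps the best hypothesis.
best_ids = tf.transpose(outputs.predicted_ids, [0, 2, 1])[:, 0, :]
```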
16 changes: 11 additions & 5 deletions model.py
@@ -26,6 +26,8 @@ def __init__(self, params):
         self.maxlen_dec_dev = params['maxlen_dec_dev'] # for loss calculation
         self.rnn_dropout = params['dropout']
         self.attn = params['attn']
+        self.beam_width = params['beam_width']
+        self.length_penalty_weight = params['length_penalty_weight']
         self.sample_prob = params['sample_prob']
         self.learning_rate = params['learning_rate']
         self.decay_step = params['decay_step'] # learning rate decay
@@ -61,16 +63,20 @@ def run(self, features, labels, mode, params):
         with tf.variable_scope('DecoderScope'):
             decoder = enc_and_dec.Decoder(self.enc_type,
                     self.attn, self.voca_size,
+                    self.beam_width, self.length_penalty_weight,
                     self.dec_layer, self.hidden_size * 2 * (self.enc_type == 'bi'),
                     self.cell_type, self.rnn_dropout,
                     self.dtype, mode, self.sample_prob)

         # Add attention wrapper to decoder cell
-        decoder.set_attentional_cell(encoder_outputs, self.enc_input_length, encoder_state, self.enc_layer)
-        self.logits = decoder.run(self.embd_dec_inputs, self.dec_input_length, self.dec_embedding, self.GO, self.EOS)
-        self.predictions = tf.argmax(self.logits, axis = -1)
-        #self.predictions = tf.Print(self.predictions, [self.predictions], message = '------********thisisit*****8--------')
+        decoder.set_attention_cell(encoder_outputs, self.enc_input_length, encoder_state, self.enc_layer)
+
+        if not (mode == tf.estimator.ModeKeys.PREDICT and self.beam_width > 0):
+            self.logits = decoder.run(self.embd_dec_inputs, self.dec_input_length, self.dec_embedding, self.GO, self.EOS)
+            self.predictions = tf.argmax(self.logits, axis = -1)
+
+        else: # Beam decoding
+            self.predictions = decoder.run(self.embd_dec_inputs, self.dec_input_length, self.dec_embedding, self.GO, self.EOS)

         self._calculate_loss(mode)
         return self._update_or_output(mode)
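One consequence visible in this diff: the greedy/train path yields logits of shape `[batch, time, vocab]` that still need an argmax, while the beam path already yields token ids of shape `[batch, time]`, which is why `self.predictions` is assigned directly there. A toy shape check, with made-up sizes:

```python
import numpy as np

batch, time, vocab = 2, 7, 100
greedy_logits = np.random.rand(batch, time, vocab)      # greedy/train output
greedy_ids = greedy_logits.argmax(axis=-1)              # -> [batch, time]
beam_ids = np.random.randint(0, vocab, (batch, time))   # beam output: ids already
assert greedy_ids.shape == beam_ids.shape == (batch, time)
```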
2 changes: 2 additions & 0 deletions params.py
@@ -18,6 +18,8 @@ def basic_params():
         decoder_layer = 2,
         dropout = 0.4,
         attn = 'normed_bahdanau', # 'bahdanau', 'normed_bahdanau', 'luong', 'scaled_luong'
+        beam_width = 5,
+        length_penalty_weight = 2.1,

         # Extra params
         dtype = tf.float32,
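For intuition on `length_penalty_weight = 2.1`: TF 1.x's `BeamSearchDecoder` divides each beam's log-probability by a GNMT-style length penalty, so larger weights favor longer outputs. A plain-Python rendering of that penalty, shown here for illustration:

```python
def length_penalty(sequence_length, weight):
    # GNMT-style penalty used by tf.contrib.seq2seq's beam search:
    # scores = log_probs / penalty, so weight > 0 rewards longer sequences.
    return ((5.0 + sequence_length) / 6.0) ** weight

for n in (5, 10, 20):
    print('length %2d -> penalty %.3f' % (n, length_penalty(n, 2.1)))
```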
1 change: 1 addition & 0 deletions run.sh
@@ -25,6 +25,7 @@ squad(){
     DEV_QUESTION='data/processed/dev_question.npy'
     TEST_SENTENCE='data/processed/test_sentence.npy'
     PRED_DIR='result/pred.txt'
+    DIC_DIR='data/processed/vocab_xinyadu.dic'
 }


