refactor: fix setup for attention
jimthompson5802 committed May 9, 2020
1 parent 3310a13 commit 10702d3
Showing 1 changed file with 6 additions and 7 deletions.
ludwig/models/modules/sequence_decoders.py (6 additions, 7 deletions)
@@ -55,7 +55,7 @@ def __init__(
             embedding_size=64,
             beam_width=1,
             num_layers=1,
-            attention_mechanism=None,
+            attention=None,
             tied_embeddings=None,
             initializer=None,
             regularize=True,
@@ -76,7 +76,7 @@ def __init__(
         self.embedding_size = embedding_size
         self.beam_width = beam_width
         self.num_layers = num_layers
-        self.attention_name = attention_mechanism
+        self.attention = attention
         self.attention_mechanism = None
         self.tied_embeddings = tied_embeddings
         self.initializer = initializer
@@ -104,18 +104,17 @@ def __init__(
         # Sampler
         self.sampler = tfa.seq2seq.sampler.TrainingSampler()

-        print('setting up attention for', attention_mechanism)
-        if attention_mechanism is not None:
-            if attention_mechanism == 'luong':
+        print('setting up attention for', attention)
+        if attention is not None:
+            if attention == 'luong':
                 self.attention_mechanism = LuongAttention(units=state_size)
-            elif attention_mechanism == 'bahdanau':
+            elif attention == 'bahdanau':
                 self.attention_mechanism = BahdanauAttention(units=state_size)

             self.decoder_rnncell = AttentionWrapper(self.decoder_rnncell,
                                                     self.attention_mechanism,
                                                     attention_layer_size=state_size)

-
         self.decoder = tfa.seq2seq.BasicDecoder(self.decoder_rnncell,
                                                 sampler=self.sampler,
                                                 output_layer=self.dense_layer)
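For readers unfamiliar with the TensorFlow Addons seq2seq API used here, the following is a minimal, self-contained sketch of the pattern this commit configures: an RNN decoder cell optionally wrapped with Luong or Bahdanau attention and handed to a `BasicDecoder`. It is not taken from the Ludwig source; the cell type, `state_size`, and the vocabulary-sized output layer are illustrative assumptions.

```python
# Illustrative sketch (not the Ludwig implementation): wiring an optional
# attention mechanism into a tfa.seq2seq decoder, mirroring the renamed
# `attention` option ('luong', 'bahdanau', or None).
import tensorflow as tf
import tensorflow_addons as tfa

state_size = 256      # assumed decoder hidden size
vocab_size = 1000     # assumed output vocabulary size
attention = 'luong'   # or 'bahdanau', or None

decoder_rnncell = tf.keras.layers.LSTMCell(state_size)
dense_layer = tf.keras.layers.Dense(vocab_size)
sampler = tfa.seq2seq.sampler.TrainingSampler()

attention_mechanism = None
if attention is not None:
    if attention == 'luong':
        attention_mechanism = tfa.seq2seq.LuongAttention(units=state_size)
    elif attention == 'bahdanau':
        attention_mechanism = tfa.seq2seq.BahdanauAttention(units=state_size)

    # Wrap the plain cell so every decoding step attends over the encoder
    # outputs; the memory itself is supplied later via setup_memory().
    decoder_rnncell = tfa.seq2seq.AttentionWrapper(
        decoder_rnncell,
        attention_mechanism,
        attention_layer_size=state_size)

decoder = tfa.seq2seq.BasicDecoder(decoder_rnncell,
                                   sampler=sampler,
                                   output_layer=dense_layer)
```

At decode time, the attention mechanism's memory (the encoder outputs) is registered via `attention_mechanism.setup_memory(...)` before the wrapped cell's initial state is built, which is why the objects above can be constructed without an encoder present.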
