visual_model.py
import os
import time
import tensorflow as tf
from nmt.utils import vocab_utils
import iterator_utils
from nmt import model_helper
from nmt.utils import misc_utils as utils
from tensorflow.python.layers.core import Dense
import helper as helper_utils
from simple_model import SimpleAttentionModel
class VisualModel(SimpleAttentionModel):

  def _create_decoder_initial_state(self, cell, hparams, batch_size,
                                    encoder_state, dtype=tf.float32):
    """Create the decoder's initial state.

    batch_size depends on the beam_width: during beam-search inference it is
    the true batch size multiplied by beam_width.
    """
    if hparams.pass_hidden_state:
      # AttentionWrapperState provides clone(), not copy(), for swapping in
      # the encoder's final state.
      decoder_initial_state = cell.zero_state(batch_size, dtype).clone(
          cell_state=encoder_state)
    else:
      decoder_initial_state = cell.zero_state(batch_size, dtype)
    return decoder_initial_state
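  # Illustrative note (an assumption, not original code): when beam search is
  # used, callers typically tile the encoder state first, so the batch_size
  # passed above is already batch_size * beam_width, e.g.:
  #
  #   encoder_state = tf.contrib.seq2seq.tile_batch(
  #       encoder_state, multiplier=hparams.beam_width)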
  def _build_decoder(self, encoder_outputs, encoder_state, hparams):
    """Build and run an RNN decoder with a final projection layer.

    Args:
      encoder_outputs: The outputs of the encoder for every time step.
      encoder_state: The final state of the encoder.
      hparams: The hyperparameter configuration.

    Returns:
      A tuple (logits, sample_id, final_context_state):
        logits: size [time, batch_size, vocab_size] when time_major=True.
    """
    ### Start and end of sequence
    tgt_sos_id = tf.cast(self.tgt_vocab_table.lookup(tf.constant(hparams.sos)),
                         tf.int32)
    tgt_eos_id = tf.cast(self.tgt_vocab_table.lookup(tf.constant(hparams.eos)),
                         tf.int32)
    iterator = self.iterator

    maximum_iterations = hparams.tgt_max_len_infer
    utils.print_out(" decoding maximum_iterations %d" % maximum_iterations)
    ## Decoder.
    with tf.variable_scope("decoder") as decoder_scope:
      # decoder_initial_state is basically zeros. This differs from the
      # encoder-decoder framework, in which the encoder's state is passed
      # into the decoder.
      # Rather than maintaining two num_units values (one for the encoder,
      # one for the decoder), we reuse a single hparams.num_units: widen it
      # before building the decoder cell, then restore it afterwards.
      hparams.num_units += hparams.visual_size
      cell, decoder_initial_state = self._build_decoder_cell(
          hparams, encoder_outputs, encoder_state,
          iterator.source_sequence_length)
      # Restore the original value in case it is needed again.
      hparams.num_units -= hparams.visual_size
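      # Net effect (sketch): the decoder cell is built with width
      #   hparams.num_units + hparams.visual_size
      # so it can consume the [word embedding; visual feature] input
      # concatenated below, while hparams itself is left unchanged.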
      ## Train or eval
      if self.mode != tf.contrib.learn.ModeKeys.INFER:
        # target_input: [max_time, batch_size] after the transpose below.
        target_input = iterator.target_input
        # target_visual: [max_time, batch_size, visual_size]
        target_visual = iterator.target_visual
        if self.time_major:
          target_input = tf.transpose(target_input)
          # target_visual is 3-D, so the permutation must be explicit;
          # a bare tf.transpose would reverse all three dimensions.
          target_visual = tf.transpose(target_visual, perm=[1, 0, 2])
        # decoder_emb_inp: [max_time, batch_size, num_units]
        decoder_emb_inp = tf.nn.embedding_lookup(
            self.embedding_decoder, target_input)
        # concatenated_input: [max_time, batch_size, num_units + visual_size]
        concatenated_input = tf.concat([decoder_emb_inp, target_visual],
                                       axis=2)
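        # Shape sketch (hypothetical sizes, not from the original code):
        #   emb = tf.zeros([5, 2, 8])      # [max_time, batch, num_units]
        #   vis = tf.zeros([5, 2, 3])      # [max_time, batch, visual_size]
        #   tf.concat([emb, vis], axis=2)  # shape (5, 2, 11)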
        # Helper
        helper = helper_utils.TrainingHelper(
            concatenated_input, iterator.target_sequence_length,
            time_major=self.time_major)

        # Decoder
        my_decoder = tf.contrib.seq2seq.BasicDecoder(
            cell,
            helper,
            decoder_initial_state,)

        # Dynamic decoding
        outputs, final_context_state, _ = tf.contrib.seq2seq.dynamic_decode(
            my_decoder,
            output_time_major=self.time_major,
            swap_memory=True,
            scope=decoder_scope)

        sample_id = outputs.sample_id

        # Note from the original code: there is a subtle difference here
        # between train and inference. We could have set output_layer when
        # creating my_decoder and shared more code between train and
        # inference. We chose to apply the output_layer to all timesteps for
        # speed: a 10% improvement for small models and 20% for larger ones.
        # If memory is a concern, apply output_layer per timestep instead.
        # Tuan's note: self.output_layer is a Dense layer projecting onto the
        # target vocabulary (one logit per output class).
        # outputs.rnn_output has shape [time, batch_size, cell_size].
        logits = self.output_layer(outputs.rnn_output)
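        # Equivalent per-timestep alternative (sketch): pass
        # output_layer=self.output_layer to BasicDecoder and read logits from
        # outputs.rnn_output, exactly as the inference branch below does.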
      ## Inference
      else:
        beam_width = hparams.beam_width
        length_penalty_weight = hparams.length_penalty_weight
        start_tokens = tf.fill([self.batch_size], tgt_sos_id)
        end_token = tgt_eos_id

        if beam_width > 0:
          my_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
              cell=cell,
              embedding=self.embedding_decoder,
              start_tokens=start_tokens,
              end_token=end_token,
              initial_state=decoder_initial_state,
              beam_width=beam_width,
              output_layer=self.output_layer,
              length_penalty_weight=length_penalty_weight)
        else:
          # Beam width may not matter much for this problem, but it is kept
          # for comparison with the reinforcement learning model.

          # Helper
          sampling_temperature = hparams.sampling_temperature
          # Sampling draws from the output distribution instead of taking the
          # argmax, then passes the result through the embedding layer to get
          # the next input. sampling_temperature controls the randomness: the
          # logits are divided by it before sampling, so values below 1.0
          # sharpen the distribution toward argmax and values above 1.0
          # flatten it.
          if sampling_temperature > 0.0:
            helper = tf.contrib.seq2seq.SampleEmbeddingHelper(
                self.embedding_decoder, start_tokens, end_token,
                softmax_temperature=sampling_temperature,
                seed=hparams.random_seed)
          else:
            helper = helper_utils.ControllerGreedyEmbeddingHelper(
                self.embedding_decoder, start_tokens, end_token)
          # Decoder
          my_decoder = tf.contrib.seq2seq.BasicDecoder(
              cell,
              helper,
              decoder_initial_state,
              output_layer=self.output_layer  # applied per timestep
          )

        # Dynamic decoding
        outputs, final_context_state, _ = tf.contrib.seq2seq.dynamic_decode(
            my_decoder,
            maximum_iterations=maximum_iterations,
            output_time_major=self.time_major,
            swap_memory=True,
            scope=decoder_scope)
        if beam_width > 0:
          logits = tf.no_op()
          sample_id = outputs.predicted_ids
        else:
          # These logits have already been passed through the dense
          # self.output_layer (applied per timestep above).
          logits = outputs.rnn_output
          sample_id = outputs.sample_id
    return logits, sample_id, final_context_state
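

# ---------------------------------------------------------------------------
# Minimal standalone sketch of the core idea above: concatenating
# per-timestep visual features onto the decoder's word embeddings and
# decoding over the widened input. It uses only stock tf.contrib.seq2seq
# APIs (not the custom helper_utils.TrainingHelper, which is assumed to
# behave the same way); all sizes are hypothetical, and this block is not
# part of the original model.
if __name__ == "__main__":
  max_time, batch_size = 5, 2
  num_units, visual_size, vocab_size = 8, 3, 12

  embeddings = tf.get_variable("demo_embeddings", [vocab_size, num_units])
  target_input = tf.zeros([max_time, batch_size], dtype=tf.int32)
  target_visual = tf.zeros([max_time, batch_size, visual_size])

  # [max_time, batch_size, num_units + visual_size], as in _build_decoder.
  emb = tf.nn.embedding_lookup(embeddings, target_input)
  concatenated = tf.concat([emb, target_visual], axis=2)

  # The cell is widened by visual_size, mirroring the num_units adjustment
  # made before _build_decoder_cell.
  cell = tf.nn.rnn_cell.BasicLSTMCell(num_units + visual_size)
  helper = tf.contrib.seq2seq.TrainingHelper(
      concatenated, tf.fill([batch_size], max_time), time_major=True)
  decoder = tf.contrib.seq2seq.BasicDecoder(
      cell, helper, cell.zero_state(batch_size, tf.float32))
  outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
      decoder, output_time_major=True)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # rnn_output: [max_time, batch_size, num_units + visual_size]
    print(sess.run(outputs.rnn_output).shape)  # (5, 2, 11)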