Skip to content

Tensorflow 1.5.0 breaking previously built models #17045

@nikita6187

Description

@nikita6187

Hello!

I've recently updated to tensorflow version 1.5.0, and suddenly receive an error, that I can't decipher, for code that worked before (in version 1.4.1): Cannot use 'transducer_training/while/rnn/strided_slice' as input to 'gradients/transducer_training/while/rnn/while/Select_1_grad/Select/f_acc' because 'transducer_training/while/rnn/strided_slice' is in a while loop

I've also tried using the softmax_cross_entropy_with_logits function, but that still produced the same error. Here's the stackoverflow post, in case its a coding mistake on my part.
The model is a seq2seq variation.


System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 16.04
  • TensorFlow installed from (source or binary): Binary
  • TensorFlow version (use command below): 1.5.0 (previous 1.4.1)
  • Python version: 2.7
  • Bazel version (if compiling from source): N/A
  • GCC/Compiler version (if compiling from source): N/A
  • CUDA/cuDNN version: N/A
  • GPU model and memory: N/A
  • Exact command to reproduce: Copy, paste and run the code
import tensorflow as tf
from tensorflow.contrib.rnn import LSTMCell, LSTMStateTuple
from tensorflow.python.layers import core as layers_core

# NOTE: Time major

# ---------------- Constants Manager ----------------------------
class ConstantsManager(object):
    def __init__(self, input_dimensions, input_embedding_size, inputs_embedded, encoder_hidden_units,
                 transducer_hidden_units, vocab_ids, input_block_size, beam_width):
        assert transducer_hidden_units == encoder_hidden_units, 'Encoder and transducer have to have the same amount' \
                                                                'of hidden units'
        self.input_dimensions = input_dimensions
        self.vocab_ids = vocab_ids
        self.E_SYMBOL = len(self.vocab_ids)
        self.vocab_ids.append('E_SYMBOL')
        self.GO_SYMBOL = len(self.vocab_ids)
        self.vocab_ids.append('GO_SYMBOL')
        self.vocab_size = len(self.vocab_ids)
        self.input_embedding_size = input_embedding_size
        self.inputs_embedded = inputs_embedded
        self.encoder_hidden_units = encoder_hidden_units
        self.transducer_hidden_units = transducer_hidden_units
        self.input_block_size = input_block_size
        self.beam_width = beam_width
        self.batch_size = 1  # Cannot be increased, see paper
        self.log_prob_init_value = 0

# ----------------- Model ---------------------------------------


class Model(object):
    def __init__(self, cons_manager):
        self.var_list = []
        self.cons_manager = cons_manager
        self.max_blocks, self.inputs_full_raw, self.transducer_list_outputs, self.start_block, self.encoder_hidden_init,\
            self.trans_hidden_init, self.logits, self.encoder_hidden_state_new, \
            self.transducer_hidden_state_new, self.train_saver = self.build_full_transducer()

        self.targets, self.train_op, self.loss = self.build_training_step()

    def build_full_transducer(self):
        with tf.variable_scope('transducer_training'):

            embeddings = tf.Variable(tf.random_uniform([self.cons_manager.vocab_size,
                                                        self.cons_manager.input_embedding_size], -1.0, 1.0),
                                     dtype=tf.float32,
                                     name='embedding')
            # Inputs
            max_blocks = tf.placeholder(dtype=tf.int32, name='max_blocks')  # total amount of blocks to go through
            if self.cons_manager.inputs_embedded is True:
                input_type = tf.float32
            else:
                input_type = tf.int32
            inputs_full_raw = tf.placeholder(shape=(None, self.cons_manager.batch_size,
                                                    self.cons_manager.input_dimensions), dtype=input_type,
                                             name='inputs_full_raw')  # shape [max_time, 1, input_dims]
            transducer_list_outputs = tf.placeholder(shape=(None,), dtype=tf.int32,
                                                     name='transducer_list_outputs')  # amount to output per block
            start_block = tf.placeholder(dtype=tf.int32, name='transducer_start_block')  # where to start the input

            encoder_hidden_init = tf.placeholder(shape=(2, 1, self.cons_manager.encoder_hidden_units), dtype=tf.float32,
                                                 name='encoder_hidden_init')
            trans_hidden_init = tf.placeholder(shape=(2, 1, self.cons_manager.transducer_hidden_units), dtype=tf.float32,
                                               name='trans_hidden_init')

            # Temporary constants, maybe changed during inference
            end_symbol = tf.get_variable(name='end_symbol',
                                         initializer=tf.constant_initializer(self.cons_manager.vocab_size),
                                         shape=(), dtype=tf.int32)

            # Turn inputs into tensor which is easily readable#

            inputs_full = tf.reshape(inputs_full_raw, shape=[-1, self.cons_manager.input_block_size,
                                                             self.cons_manager.batch_size,
                                                             self.cons_manager.input_dimensions])

            # Outputs
            outputs_ta = tf.TensorArray(dtype=tf.float32, size=max_blocks)

            init_state = (start_block, outputs_ta, encoder_hidden_init, trans_hidden_init)

            # Initiate cells, NOTE: if there is a future error, put these back inside the body function
            encoder_cell = tf.contrib.rnn.LSTMCell(num_units=self.cons_manager.encoder_hidden_units)
            transducer_cell = tf.contrib.rnn.LSTMCell(self.cons_manager.transducer_hidden_units)

            def cond(current_block, outputs_int, encoder_hidden, trans_hidden):
                return current_block < start_block + max_blocks

            def body(current_block, outputs_int, encoder_hidden, trans_hidden):

                # --------------------- ENCODER ----------------------------------------------------------------------
                encoder_inputs = inputs_full[current_block]
                encoder_inputs_length = [tf.shape(encoder_inputs)[0]]
                encoder_hidden_state = encoder_hidden

                if self.cons_manager.inputs_embedded is True:
                    encoder_inputs_embedded = encoder_inputs
                else:
                    encoder_inputs = tf.reshape(encoder_inputs, shape=[-1, self.cons_manager.batch_size])
                    encoder_inputs_embedded = tf.nn.embedding_lookup(embeddings, encoder_inputs)

                # Build model

                # Build previous state
                encoder_hidden_c, encoder_hidden_h = tf.split(encoder_hidden_state, num_or_size_splits=2, axis=0)
                encoder_hidden_c = tf.reshape(encoder_hidden_c, shape=[-1, self.cons_manager.encoder_hidden_units])
                encoder_hidden_h = tf.reshape(encoder_hidden_h, shape=[-1, self.cons_manager.encoder_hidden_units])
                encoder_hidden_state_t = LSTMStateTuple(encoder_hidden_c, encoder_hidden_h)

                #   encoder_outputs: [max_time, batch_size, num_units]
                encoder_outputs, encoder_hidden_state_new = tf.nn.dynamic_rnn(
                    encoder_cell, encoder_inputs_embedded,
                    sequence_length=encoder_inputs_length, time_major=True,
                    dtype=tf.float32, initial_state=encoder_hidden_state_t)

                # Modify output of encoder_hidden_state_new so that it can be fed back in again without problems.
                encoder_hidden_state_new = tf.concat([encoder_hidden_state_new.c, encoder_hidden_state_new.h], axis=0)
                encoder_hidden_state_new = tf.reshape(encoder_hidden_state_new,
                                                      shape=[2, -1, self.cons_manager.encoder_hidden_units])

                # --------------------- TRANSDUCER --------------------------------------------------------------------
                encoder_raw_outputs = encoder_outputs
                # Save/load the state as one tensor, use encoder state as init if this is the first block
                trans_hidden_state = tf.cond(current_block > 0, lambda: trans_hidden, lambda: encoder_hidden_state_new)
                transducer_amount_outputs = transducer_list_outputs[current_block - start_block]

                # Model building
                helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
                    embedding=embeddings,
                    start_tokens=tf.tile([self.cons_manager.GO_SYMBOL],
                                         [self.cons_manager.batch_size]),  # TODO: check if this looks good
                    end_token=end_symbol)  # vocab size, so that it doesn't prematurely end the decoding

                attention_states = tf.transpose(encoder_raw_outputs,
                                                [1, 0, 2])  # attention_states: [batch_size, max_time, num_units]

                attention_mechanism = tf.contrib.seq2seq.LuongAttention(
                    self.cons_manager.encoder_hidden_units, attention_states)

                decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
                    transducer_cell,
                    attention_mechanism,
                    attention_layer_size=self.cons_manager.transducer_hidden_units)

                projection_layer = layers_core.Dense(self.cons_manager.vocab_size, use_bias=False)

                # Build previous state
                trans_hidden_c, trans_hidden_h = tf.split(trans_hidden_state, num_or_size_splits=2, axis=0)
                trans_hidden_c = tf.reshape(trans_hidden_c, shape=[-1, self.cons_manager.transducer_hidden_units])
                trans_hidden_h = tf.reshape(trans_hidden_h, shape=[-1, self.cons_manager.transducer_hidden_units])
                trans_hidden_state_t = LSTMStateTuple(trans_hidden_c, trans_hidden_h)

                decoder = tf.contrib.seq2seq.BasicDecoder(
                    decoder_cell, helper,
                    decoder_cell.zero_state(1, tf.float32).clone(cell_state=trans_hidden_state_t),
                    output_layer=projection_layer)

                outputs, transducer_hidden_state_new, _ = tf.contrib.seq2seq.dynamic_decode(decoder,
                                                                                            output_time_major=True,
                                                                                            maximum_iterations=transducer_amount_outputs)
                logits = outputs.rnn_output  # logits of shape [max_time,batch_size,vocab_size]
                decoder_prediction = outputs.sample_id  # For debugging

                # Modify output of transducer_hidden_state_new so that it can be fed back in again without problems.
                transducer_hidden_state_new = tf.concat(
                    [transducer_hidden_state_new[0].c, transducer_hidden_state_new[0].h],
                    axis=0)
                transducer_hidden_state_new = tf.reshape(transducer_hidden_state_new,
                                                         shape=[2, -1, self.cons_manager.transducer_hidden_units])


                # Note the outputs
                outputs_int = outputs_int.write(current_block - start_block, logits)

                return current_block + 1, outputs_int, encoder_hidden_state_new, transducer_hidden_state_new

            _, outputs_final, encoder_hidden_state_new, transducer_hidden_state_new = \
                tf.while_loop(cond, body, init_state, parallel_iterations=1)

            # Process outputs
            outputs = outputs_final.concat()
            logits = tf.reshape(
                outputs,
                shape=(-1, 1, self.cons_manager.vocab_size))  # And now its [max_output_time, batch_size, vocab]

            # For loading the model later on
            logits = tf.identity(logits, name='logits')
            encoder_hidden_state_new = tf.identity(encoder_hidden_state_new, name='encoder_hidden_state_new')
            transducer_hidden_state_new = tf.identity(transducer_hidden_state_new, name='transducer_hidden_state_new')

        train_saver = tf.train.Saver()  # For now save everything

        return max_blocks, inputs_full_raw, transducer_list_outputs, start_block, encoder_hidden_init,\
            trans_hidden_init, logits, encoder_hidden_state_new, transducer_hidden_state_new, train_saver

    def build_training_step(self):
        targets = tf.placeholder(shape=(None,), dtype=tf.int32, name='targets')
        targets_one_hot = tf.one_hot(targets, depth=self.cons_manager.vocab_size, dtype=tf.float32)

        targets_one_hot = tf.Print(targets_one_hot, [targets], message='Targets: ', summarize=10)
        targets_one_hot = tf.Print(targets_one_hot, [tf.argmax(self.logits, axis=2)], message='Argmax: ', summarize=10)

        stepwise_cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=targets_one_hot,
                                                                         logits=self.logits)
        loss = tf.reduce_mean(stepwise_cross_entropy)
        train_op = tf.train.AdamOptimizer().minimize(loss)
        return targets, train_op, loss


constants_manager = ConstantsManager(input_dimensions=1, input_embedding_size=11, inputs_embedded=False,
                                     encoder_hidden_units=100, transducer_hidden_units=100, vocab_ids=[0, 1, 2],
                                     input_block_size=1, beam_width=5)
model = Model(cons_manager=constants_manager)

I can try and make a smaller fail case if needed.

Thanks!
Nikita

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions