RNN's state_is_tuple doesn't work with initial_state #2695

Closed
bogdanteleaga opened this issue Jun 6, 2016 · 6 comments

Comments

bogdanteleaga commented Jun 6, 2016

Suppose I want to batch a series of inputs and propagate the cell state from one session run to the next over an epoch:

state = initial_state.eval()
for batch in epoch:
    feed_dict = {initial_state: state}
    state = sess.run(final_state, feed_dict)

Since setting state_is_tuple on the cells makes the returned state a tuple:

  • calling .eval() on the initial state doesn't work, and
  • subsequent states are returned as tuples and cannot be fed back into the session as tuples.
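For concreteness, here is a minimal untested sketch of where this breaks (the cell size and batch size are made up, using the TF 0.x-era rnn_cell API):

import tensorflow as tf

# Hypothetical minimal repro (TF 0.x-era API; sizes made up).
cell = tf.nn.rnn_cell.BasicLSTMCell(128, state_is_tuple=True)
initial_state = cell.zero_state(32, tf.float32)  # an LSTMStateTuple(c, h)

# initial_state is a namedtuple of two Tensors, not a single Tensor, so:
#   initial_state.eval()                 # fails: tuples have no .eval()
#   feed_dict = {initial_state: state}   # fails: the key is not a single Tensor
# Only the components work, e.g. initial_state.c.eval() and initial_state.h.eval()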
jihunchoi (Contributor) commented

When you use state_is_tuple, it literally means that the state is a Python tuple.
However, sess.run's first argument (the fetches) should be a list of tensors, and the second (feed_dict) should be a dictionary whose keys are tensors.
As far as I know, you have to split the state tuple yourself.
So your code should look something like the form below (I haven't tested it):

state_c = initial_state[0].eval()
state_m = initial_state[1].eval()
for batch in epoch:
    feed_dict = {initial_state[0]: state_c, initial_state[1]: state_m}
    state_c, state_m = sess.run([final_state[0], final_state[1]], feed_dict)

Note that this is applicable only to simple RNNs without stacked layers; for the multi-layer case, check out my modification of the RNN PTB example: https://github.com/jihunchoi/tensorflow/blob/ptb_use_state_tuple/tensorflow/models/rnn/ptb/ptb_word_lm.py
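(For later readers, a minimal untested sketch of that multi-layer case, assuming a MultiRNNCell of LSTMCells with state_is_tuple=True, so that initial_state and final_state are tuples of LSTMStateTuple(c, h) pairs; inputs and epoch are hypothetical names:)

# Assumes a TF version whose Session.run accepts nested tuple fetches;
# otherwise fetch each component tensor individually.
state = sess.run(initial_state)  # nested tuple of numpy arrays
for batch in epoch:
    feed_dict = {inputs: batch}
    for layer, (c, h) in enumerate(state):
        feed_dict[initial_state[layer].c] = c
        feed_dict[initial_state[layer].h] = h
    state = sess.run(final_state, feed_dict)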

bogdanteleaga (Author) commented

Thank you for the example; it worked, and it is indeed faster :)
It does seem to be quite a struggle to get exactly right, though. Perhaps being able to pass tuples into the feed dict would help.
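(For later readers: newer TF 1.x releases support exactly this. Session.run documents that fetches may be arbitrarily nested structures and that feed_dict keys may be nested tuples of tensors. A sketch assuming such a version, with the model.* names hypothetical:)

state = sess.run(model.initial_state)  # nested tuple of numpy arrays
for x in batches:  # 'batches' is a hypothetical batch iterator
    feed = {model.inputs: x, model.initial_state: state}  # tuple key allowed here
    outputs, state = sess.run([model.outputs, model.final_state], feed)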

ebrevdo (Contributor) commented Jun 13, 2016

This may be something we're working on. Would you like to open a new bug to track it, explicitly describing the semantics of how you'd like to be able to pass tuples into feed_dict? Closing this bug for now. Thanks for answering, @jihunchoi!

memo (Contributor) commented Aug 2, 2016

Sorry to comment on this closed issue, but I ran into the exact same problem when trying to get rid of the state_is_tuple warning. The tuple-less state was very nice and manageable (at least for the basic stuff I was doing): a single tensor passed in and out. With the tuple state, if I have a variable number of cells (e.g. when trying different hyperparameters), the code becomes a bit uglier.

I wrote something like the function below, which returns a dict I can use to feed the initial state. But then the final state also becomes a problem, and I'm not sure what the best way to manage this dynamically is. Is this really the best way?

def init_state(self):
    # Evaluate every state tensor in every layer and map tensor -> value,
    # so the result can be passed directly as (part of) a feed_dict.
    states_dict = {}
    for layer in self.initial_states_:
        for state in layer:  # each layer is an LSTMStateTuple (c, h)
            states_dict[state] = state.eval()
    return states_dict
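(One answer for the fetch side is the mirror image of the function above: flatten final_state into a plain list of tensors, pass that list as the sess.run fetches, and zip the returned values back onto the initial-state tensors. A sketch, where final_states_ is a hypothetical attribute shaped like initial_states_:)

def final_state_fetches(self):
    # Flatten the per-layer state tuples into one list of tensors,
    # usable directly as a sess.run fetch list.
    return [state for layer in self.final_states_ for state in layer]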

memo (Contributor) commented Aug 3, 2016

Just a follow-up on this: I think I have it fully working now (tested on a small model). For Graves-style sequence generation it took quite a bit of wrangling.

This was the original code (relevant bits only), without tuples:

state = model.initial_state.eval()
loop:
    feed = {model.inputs: x, model.initial_state: state}
    outputs, state = sess.run([model.outputs, model.final_state], feed)

and this is what it took (relevant bits only) to get it working with tuples. It would be great to wrap this up somehow and make it easier:

def get_states_list(states, state_is_tuple):
    """
    given a 'states' variable from a tensorflow model,
    return a flattened list of states
    """
    if state_is_tuple:
        states_list = [] # flattened list of all tensors in states
        for layer in states:
            for state in layer:
                states_list.append(state)
    else:
        states_list = [states]

    return states_list


def get_states_dict(states, state_is_tuple):
    """
    given a 'states' variable from a tensorflow model,
    return a dict of { tensor : evaluated value }
    """
    if state_is_tuple:
        states_dict = {} # dict of { tensor : value } 
        for layer in states:
            for state in layer:
                states_dict[state] = state.eval()

    else:
        states_dict = {states : states.eval()}

    return states_dict

init_states_list = utils.get_states_list(model.initial_state, model.state_is_tuple)
init_states_dict = utils.get_states_dict(model.initial_state, model.state_is_tuple)
final_states_list = utils.get_states_list(model.final_state, model.state_is_tuple)

states_dict = init_states_dict

loop:
    feed = {model.inputs:x}
    feed.update(states_dict)
    fetch = [model.outputs] + final_states_list
    ret = sess.run(fetch, feed)
    outputs = ret[0]
    states = ret[1:]
    states_dict = dict(zip(init_states_list, states))
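(One possible way to wrap this up: TensorFlow ships nest utilities that flatten arbitrarily nested state structures. They lived at tensorflow.python.util.nest around this time and became public later as tf.nest; a sketch under that assumption, reusing the model.* names from above:)

from tensorflow.python.util import nest  # private path; public as tf.nest in later releases

flat_initial = nest.flatten(model.initial_state)  # flat list of state tensors
flat_final = nest.flatten(model.final_state)

state_values = sess.run(flat_initial)  # initial numpy values
for x in batches:  # 'batches' is a hypothetical batch iterator
    feed = {model.inputs: x}
    feed.update(zip(flat_initial, state_values))  # tensor -> value pairs
    ret = sess.run([model.outputs] + flat_final, feed)
    outputs, state_values = ret[0], ret[1:]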

zsdonghao commented

Hi, there is a tutorial_ptb_lstm_state_is_tuple in the TensorLayer repo; hope it helps.
https://github.com/zsdonghao/tensorlayer
