
How to use tf.nn.dropout to implement embedding dropout #14746

Closed
zotroneneis opened this issue Nov 21, 2017 · 9 comments

@zotroneneis

zotroneneis commented Nov 21, 2017

Recent papers in language modeling use a specific form of embedding dropout that was proposed in this paper. The paper also proposed variational recurrent dropout, which was already discussed in this issue.

In embedding dropout, the same dropout mask is used at each timestep and entire words are dropped (i.e. the whole word vector of a word is set to zero). This behavior can be achieved by providing a noise_shape to tf.nn.dropout. In addition, the same words are dropped throughout a sequence:

"Since we repeat the same mask at each time step, we drop the same words throughout the sequence – i.e. we drop word types at random rather than word tokens (as an example, the sentence “the dog and the cat” might become “— dog and — cat” or “the — and the cat”, but never “— dog and the cat”). "

I couldn't find a way to implement this form of embedding dropout efficiently. Are there any plans to incorporate these advances?
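For reference, a minimal sketch (my own, assuming TF 1.x's tf.nn.dropout and illustrative sizes) of how such a noise_shape broadcasts one keep/drop decision across each row of the embedding matrix:

import tensorflow as tf

# illustrative sizes, not from the paper
vocab_size, embed_size = 10000, 300
embedding = tf.get_variable("embedding", shape=[vocab_size, embed_size])

# noise_shape=[vocab_size, 1] samples one Bernoulli variable per row and
# broadcasts it across the embedding dimension, so whole word vectors are
# zeroed, i.e. word types are dropped rather than individual dimensions.
dropped = tf.nn.dropout(embedding, keep_prob=0.9, noise_shape=[vocab_size, 1])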

@zotroneneis zotroneneis changed the title How to get embedding dropout How to use tf.nn.dropout to implement embedding dropout Nov 21, 2017
@alexeygrigorev

@zotroneneis did you figure out how to do it?

@zotroneneis
Author

Yes, I did

@alexeygrigorev

@zotroneneis can you show a code example?

@zotroneneis
Author

zotroneneis commented Jan 5, 2018

Sure! You just have to drop random rows of the embedding matrix (i.e. set them to zero). This corresponds to dropping certain words in the input sequence.

# initialize random embedding matrix
with tf.variable_scope('embedding'):
    self.embedding_matrix = tf.get_variable("embedding",
                                            shape=[self.vocab_size, self.embd_size],
                                            dtype=tf.float32,
                                            initializer=self.initializer)

# noise_shape=[vocab_size, 1] draws one keep/drop decision per row, so entire
# word vectors are zeroed (the same word types for the whole batch)
with tf.name_scope("embedding_dropout"):
    self.embedding_matrix = tf.nn.dropout(self.embedding_matrix,
                                          keep_prob=self.embedding_dropout,
                                          noise_shape=[self.vocab_size, 1])

with tf.name_scope('input'):
    self.input_batch = tf.placeholder(tf.int64, shape=(None, None))
    self.inputs = tf.nn.embedding_lookup(self.embedding_matrix, self.input_batch)

...
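One practical detail not shown above (my addition, standard TF 1.x practice rather than part of the original snippet): keep_prob should be fed in a way that lets you disable dropout at evaluation time, e.g. via a placeholder with a default of 1.0:

# defaults to 1.0 (no dropout); feed e.g. {self.embedding_dropout: 0.9} during training
self.embedding_dropout = tf.placeholder_with_default(1.0, shape=[], name='embedding_dropout')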


@louiskirsch

@zotroneneis But this drops the same words for the entire batch -- this might make learning more noisy, correct?

@makai281

makai281 commented May 31, 2018

Is it supposed to be like this?

# initialize random embedding matrix
with tf.variable_scope('embedding'):
    self.embedding_matrix = tf.get_variable("embedding", shape=[self.vocab_size, self.embd_size], dtype=tf.float32, initializer=self.initializer)

with tf.name_scope('input'):
    self.input_batch = tf.placeholder(tf.int64, shape=(None, None))
    self.inputs = tf.nn.embedding_lookup(self.embedding_matrix, self.input_batch)

# one keep/drop decision per token position, broadcast across the embedding dimension
with tf.name_scope('word_dropout'):
    self.inputs = tf.nn.dropout(self.inputs, keep_prob=self.word_dropout, noise_shape=[tf.shape(self.inputs)[0], tf.shape(self.inputs)[1], 1])

@louiskirsch

louiskirsch commented May 31, 2018

@makai281 No, this would be dropout of the embedding vectors (which is also a good idea, but not what we were referring to here). We want to drop out inputs (= word types) efficiently.

@simtony

simtony commented Sep 4, 2018

Directly dropping the whole embedding matrix is inefficient for a large vocabulary, since it requires vocab_size * batch_size random variable evaluations.
The implementation below cuts the number of random variable evaluations down to num_uniq_words_in_batch * batch_size.

import tensorflow as tf

# preparing inputs
keep_prob = 0.8
ids = tf.convert_to_tensor([[1,2,3], [3,2,1]])
batch_size = tf.shape(ids)[0]
maxlen = tf.shape(ids)[1]
embed_matrix = tf.ones([10, 20])
embed_ids = tf.nn.embedding_lookup(embed_matrix, ids)

uniq_ids, indices = tf.unique(tf.reshape(ids, [-1]))
# generate random mask for each uniq_id
# independent sample for each instance
rand_mask = tf.random_uniform([batch_size, tf.size(uniq_ids)], dtype=embed_matrix.dtype)

# prepare indices for tf.gather_nd
batch_wise = tf.broadcast_to(tf.expand_dims(tf.range(batch_size), axis=-1), [batch_size, maxlen])
uniq_ids_wise = tf.reshape(indices, [batch_size, maxlen])

# gather mask and convert it to binary mask
mask_indices = tf.stack([batch_wise, uniq_ids_wise], axis=-1)
binary_mask = tf.floor(tf.gather_nd(rand_mask, mask_indices) + keep_prob)

# apply mask and scale
dropped_embeddings = embed_ids * tf.expand_dims(binary_mask, axis=-1) / keep_prob

And here is the result:

In [52]: dropped_embeddings.eval()
Out[52]: 
array([[[1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25,
         1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25],
        [1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25,
         1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25],
        [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
         0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ]],

       [[1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25,
         1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25],
        [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
         0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
        [1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25,
         1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25]]],
      dtype=float32)

In [53]: dropped_embeddings.eval()
Out[53]: 
array([[[1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25,
         1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25],
        [1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25,
         1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25],
        [1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25,
         1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25]],

       [[1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25,
         1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25],
        [1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25,
         1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25],
        [1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25,
         1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25, 1.25]]],
      dtype=float32)
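For reuse, the same steps can be wrapped into a helper function. This is just a sketch refactoring the snippet above (the function name is mine, not from this thread), assuming TF 1.x with tf.broadcast_to available (>= 1.12):

import tensorflow as tf

def word_type_dropout_lookup(ids, embed_matrix, keep_prob):
    # embedding lookup with word-type dropout, independent per sample
    batch_size = tf.shape(ids)[0]
    maxlen = tf.shape(ids)[1]
    embed_ids = tf.nn.embedding_lookup(embed_matrix, ids)

    # one random draw per (sample, unique word id in the batch)
    uniq_ids, indices = tf.unique(tf.reshape(ids, [-1]))
    rand_mask = tf.random_uniform([batch_size, tf.size(uniq_ids)], dtype=embed_matrix.dtype)

    # map every token position back to its (sample, unique-id) random draw
    batch_wise = tf.broadcast_to(tf.expand_dims(tf.range(batch_size), axis=-1), [batch_size, maxlen])
    uniq_ids_wise = tf.reshape(indices, [batch_size, maxlen])
    mask_indices = tf.stack([batch_wise, uniq_ids_wise], axis=-1)

    # turn the draws into a binary keep mask, apply it and rescale
    binary_mask = tf.floor(tf.gather_nd(rand_mask, mask_indices) + keep_prob)
    return embed_ids * tf.expand_dims(binary_mask, axis=-1) / keep_prob

Calling it as word_type_dropout_lookup(tf.convert_to_tensor([[1, 2, 3], [3, 2, 1]]), tf.ones([10, 20]), keep_prob=0.8) should reproduce the behaviour shown above.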

@gzerveas

Quoting @zotroneneis above:

Sure! You just have to drop random rows of the embedding matrix (i.e. set them to zero). This corresponds to dropping certain words in the input sequence.

# initialize random embedding matrix
with tf.variable_scope('embedding'):
   self.embedding_matrix = tf.get_variable( "embedding", shape=[self.vocab_size, self.embd_size], dtype=tf.float32, initializer=self.initializer)

with tf.name_scope("embedding_dropout"):
   self.embedding_matrix = tf.nn.dropout(self.embedding_matrix, keep_prob=self.embedding_dropout, noise_shape=[self.vocab_size,1])

with tf.name_scope('input'):
   self.input_batch = tf.placeholder(tf.int64, shape=(None, None))
   self.inputs = tf.nn.embedding_lookup(self.embedding_matrix, self.input_batch)

...

The biggest problem I see with this is that typically only a tiny fraction of the overall vocabulary appears in a given batch. Dropping a uniformly random (instead of word-frequency-informed) 20% of the vocabulary will most of the time do nothing. To have a non-negligible effect, one would need to increase the dropout rate to very high values (e.g. 70%), at which point the danger is that some sentences (or batches, since the same words are dropped for the entire batch) will be dropped almost entirely. So my impression is that training will become noisy. The solution is to drop word IDs that actually occur, preferably separately for each sample in the batch.
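A minimal sketch of that idea (mine, assuming TF 1.x; names are illustrative): draw one keep/drop decision per (sample, word type) and look it up per token, so only IDs that actually occur in a sample have any effect and each sample gets its own mask. Note that it materializes a [batch_size, vocab_size] random tensor, which @simtony's tf.unique version above avoids:

import tensorflow as tf

def per_sample_word_type_dropout(ids, embedding_matrix, keep_prob):
    # ids: [batch, time] integer token ids; embedding_matrix: [vocab, dim]
    batch_size = tf.shape(ids)[0]
    vocab_size = tf.shape(embedding_matrix)[0]

    # one keep/drop decision per (sample, word type)
    rand = tf.random_uniform([batch_size, vocab_size], dtype=embedding_matrix.dtype)
    keep = tf.floor(rand + keep_prob)  # [batch, vocab] of 0/1

    # look up each token's decision within its own sample
    batch_idx = tf.broadcast_to(tf.expand_dims(tf.range(batch_size), 1), tf.shape(ids))
    gather_idx = tf.stack([tf.cast(batch_idx, ids.dtype), ids], axis=-1)
    mask = tf.gather_nd(keep, gather_idx)  # [batch, time]

    embedded = tf.nn.embedding_lookup(embedding_matrix, ids)
    return embedded * tf.expand_dims(mask, -1) / keep_prob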
