In [1]:
import tensorflow as tf
import numpy as np

In [8]:
# Reproducible dummy data.
np.random.seed(10)
# maximum number of time steps
T = 8
# batch size
B = 4
rnn_dim = 128
num_classes = 10

# True (unpadded) length of each example in the batch; one entry per example.
example_len = [1, 2, 3, 8]
assert len(example_len) == B, "need exactly one length per example in the batch"
# A zero length would make the per-example mean loss divide 0 by 0 later on.
assert all(1 <= n <= T for n in example_len), "each length must lie in [1, T]"

# Dummy labels drawn from 1..num_classes-1. Label 0 is reserved as the padding
# id: the mask built later keys on `label != 0`, so real labels must never be 0.
y = np.random.randint(1, num_classes, [B, T])  # [batch_size x timesteps]
for i, length in enumerate(example_len):
    y[i, length:] = 0  # pad everything past the example's true length
y = y.astype(np.int64)

In [16]:
# Start from an empty default graph, then fix its graph-level seed so that the
# random initializer of `w` is reproducible across runs.
tf.reset_default_graph()
tf.set_random_seed(10)

# Stand-in for real RNN outputs: random activations of shape [B, T, rnn_dim],
# stored as a float32 constant.
rnn_outputs = tf.constant(np.random.randn(B, T, rnn_dim), dtype=tf.float32)

# Projection from RNN state to class logits, shape [rnn_dim, num_classes].
w = tf.get_variable('w', [rnn_dim, num_classes],
                    initializer=tf.random_normal_initializer())


In [17]:
# flatten rnn_outputs
rnn_outputs_flat = tf.reshape(rnn_outputs,[-1, rnn_dim])
# flat logits
logits_flat = tf.batch_matmul(rnn_outputs_flat,w)
# flat probabilities
probs_flat = tf.nn.softmax(logits_flat)


In [18]:
# calculate loss
#   reshape y placeholder to calc loss
y_flat = tf.reshape(y,[-1])
losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits_flat,y_flat)

In [22]:
# THE MASK.. finally
mask = tf.sign(tf.to_float(y_flat))
# masked losses
masked_losses = mask * losses
# now back to normal shape
masked_losses = tf.reshape(masked_losses,tf.shape(y))

In [23]:
# mean loss
mean_loss_by_example = tf.reduce_sum(masked_losses, reduction_indices=1)/example_len
mean_loss = tf.reduce_mean(mean_loss_by_example)

In [24]:
# Evaluate the tensors once. tf.contrib.learn.run_n is deprecated (and
# tf.contrib is gone in TF 2), so use an explicit session. `w` must be
# initialized before anything that depends on it can run. The fetched dict is
# wrapped in a list to keep the `results[0][...]` access pattern that
# run_n (n=1) produced.
fetches = {
    'masked_losses': masked_losses,
    'mean_loss_by_example': mean_loss_by_example,
    'mean_loss': mean_loss,
}
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    results = [sess.run(fetches)]

In [30]:
masked_losses_val = results[0]['masked_losses']
# Sanity check instead of eyeballing: every padded step (t >= the example's
# length) must contribute exactly zero loss.
for i, length in enumerate(example_len):
    assert not masked_losses_val[i, length:].any(), \
        "example %d has non-zero loss on padded steps" % i
masked_losses_val

array([[  7.53408384e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00,   0.00000000e+00],
       [  1.45501976e+01,   1.90733044e-05,   0.00000000e+00,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00,   0.00000000e+00],
       [  2.30566845e+01,   1.94982891e+01,   3.09809460e-03,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00,
          0.00000000e+00,   0.00000000e+00],
       [  1.12082520e+01,   4.99081879e+01,   3.27171707e+01,
          5.38984680e+01,   8.67770576e+00,   6.23956108e+00,
          1.49090614e+01,   1.77541542e+01]], dtype=float32)