Skip to content
This repository has been archived by the owner on Mar 31, 2019. It is now read-only.

Commit

Permalink
inverse get_mask
Browse files Browse the repository at this point in the history
  • Loading branch information
justheuristic committed May 10, 2017
1 parent 97132e4 commit ce3320e
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions agentnet/learning/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,13 @@ def get_end_indicator(is_alive, force_end_at_t_max=False):

return is_end

def get_mask_by_eos(is_eos):
"""takes indicator of "it ends now", returns mask. Ignores everything after first end.
:param is_eos: indicator that is 0 for all
:type is_eos: theano.matrix
"""
assert is_eos.ndim==2
is_right_after_eos = T.concatenate([T.zeros_like(is_eos[:,:1]),is_eos[:,:-1]],-1)
is_after_eos = T.eq(T.cumsum(is_right_after_eos,axis=-1),0).astype('uint8')
return is_after_eos

0 comments on commit ce3320e

Please sign in to comment.