In [5]:
import numpy as np
import tensorflow as tf

# Extract Rows

Extracting the rows that belong to a given policy would be applied before the data is split into batches.

In [6]:
# Toy (4, 3) matrix holding 0..11 so that every row is easy to recognise.
data = tf.reshape(tf.range(12, dtype=float), shape=(4, 3))
data

<tf.Tensor: shape=(4, 3), dtype=float32, numpy=
array([[ 0.,  1.,  2.],
       [ 3.,  4.,  5.],
       [ 6.,  7.,  8.],
       [ 9., 10., 11.]], dtype=float32)>

In [7]:
# One policy id per row of `data`, stored as float32.
policies = tf.constant([0, 1, 0, 1], dtype=tf.float32)
policies

<tf.Tensor: shape=(4,), dtype=float32, numpy=array([0., 1., 0., 1.], dtype=float32)>

In [8]:
# Keep only the rows of `data` whose policy id is 0.
# tf.where yields one index row per match, shape (n_matches, 1). Squeezing only
# the last axis keeps the index vector 1-D even when exactly one row matches, so
# the gathered result is always 2-D (n_matches, n_cols) instead of collapsing to
# a single 1-D row.
tf.gather(data, tf.squeeze(tf.where(policies == 0.), axis=-1))

<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[0., 1., 2.],
       [6., 7., 8.]], dtype=float32)>

# Setup

In [7]:
# Scalar hyper-parameter; no other cell visible in this notebook reads it.
# NOTE(review): presumably the weight on the variance penalty computed further down — confirm.
rex_beta = 0.1

In [8]:
# Toy loss tensor holding 0..23 laid out as (4, 2, 3).
losses_arr = tf.reshape(tf.range(24, dtype=float), shape=(4, 2, 3))
losses = tf.constant(losses_arr)
losses

<tf.Tensor: shape=(4, 2, 3), dtype=float32, numpy=
array([[[ 0.,  1.,  2.],
        [ 3.,  4.,  5.]],

       [[ 6.,  7.,  8.],
        [ 9., 10., 11.]],

       [[12., 13., 14.],
        [15., 16., 17.]],

       [[18., 19., 20.],
        [21., 22., 23.]]], dtype=float32)>

In [9]:
# Same 0..23 values and (4, 2, 3) layout as `losses`.
# NOTE(review): the name `input` shadows the Python builtin of the same name.
input = tf.constant(
    tf.reshape(tf.range(24, dtype=float), shape=(4, 2, 3))
)
input

<tf.Tensor: shape=(4, 2, 3), dtype=float32, numpy=
array([[[ 0.,  1.,  2.],
        [ 3.,  4.,  5.]],

       [[ 6.,  7.,  8.],
        [ 9., 10., 11.]],

       [[12., 13., 14.],
        [15., 16., 17.]],

       [[18., 19., 20.],
        [21., 22., 23.]]], dtype=float32)>

In [10]:
# policies = tf.constant([
#     [
#         [1.],
#         [1.]
#     ],
#     [
#         [1.],
#         [0.]
#     ],
#     [
#         [0.],
#         [2.]
#     ],
#     [
#         [0.],
#         [1.]
#     ]

# ])
# policies

In [11]:
# Policy ids laid out as (4, 2, 1): one id for each of the (4, 2) leading positions.
policies = tf.constant(
    [[[1.], [1.]],
     [[1.], [0.]],
     [[0.], [2.]],
     [[0.], [1.]]]
)
policies

<tf.Tensor: shape=(4, 2, 1), dtype=float32, numpy=
array([[[1.],
        [1.]],

       [[1.],
        [0.]],

       [[0.],
        [2.]],

       [[0.],
        [1.]]], dtype=float32)>

In [12]:
# Average `losses` over its last axis, keeping that axis with size 1 -> (4, 2, 1).
obs_losses = tf.math.reduce_mean(losses, axis=-1, keepdims=True)
obs_losses

<tf.Tensor: shape=(4, 2, 1), dtype=float32, numpy=
array([[[ 1.],
        [ 4.]],

       [[ 7.],
        [10.]],

       [[13.],
        [16.]],

       [[19.],
        [22.]]], dtype=float32)>

In [13]:
# Distinct policy ids; `policies` is flattened first and the values part of
# tf.unique's (values, indices) result is kept.
unique_pols = tf.unique(tf.reshape(policies, shape=(-1,)))[0]
unique_pols

<tf.Tensor: shape=(3,), dtype=float32, numpy=array([1., 0., 2.], dtype=float32)>

In [27]:
# Index triples where the policy id equals 0, with the last index column dropped.
tf.where(tf.equal(policies, 0.))[:, :2]

<tf.Tensor: shape=(3, 2), dtype=int64, numpy=
array([[1, 1],
       [2, 0],
       [3, 0]])>

In [26]:
# Pull the length-3 loss vectors at the positions whose policy id equals 0.
tf.gather_nd(params=losses, indices=tf.where(tf.equal(policies, 0.))[:, :2])

<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[ 9., 10., 11.],
       [12., 13., 14.],
       [18., 19., 20.]], dtype=float32)>

In [18]:
# Experiment: use each policy id as an index into axis 0 of `losses`, then keep
# index 1 along the third axis of the gathered result.
tf.gather_nd(params=losses, indices=tf.cast(policies, dtype=tf.int32))[:, :, 1, :]

<tf.Tensor: shape=(4, 2, 3), dtype=float32, numpy=
array([[[ 9., 10., 11.],
        [ 9., 10., 11.]],

       [[ 9., 10., 11.],
        [ 3.,  4.,  5.]],

       [[ 3.,  4.,  5.],
        [15., 16., 17.]],

       [[ 3.,  4.,  5.],
        [ 9., 10., 11.]]], dtype=float32)>

# Solution

In [9]:
# Switch the policy ids to int32; they are passed to tf.one_hot below.
policies = tf.cast(policies, dtype=tf.int32)
policies

<tf.Tensor: shape=(4, 2, 1), dtype=int32, numpy=
array([[[1],
        [1]],

       [[1],
        [0]],

       [[0],
        [2]],

       [[0],
        [1]]], dtype=int32)>

In [10]:
# Distinct policy ids recomputed from the int32 `policies`.
unique_pols = tf.unique(tf.reshape(policies, shape=(-1,)))[0]
unique_pols

<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 0, 2], dtype=int32)>

In [11]:
# Largest policy id plus one, i.e. the one-hot depth used in the next cell.
tf.reduce_max(unique_pols) + 1

<tf.Tensor: shape=(), dtype=int32, numpy=3>

In [12]:
# One-hot encode the policy ids. tf.one_hot appends the new axis after the
# trailing size-1 axis of `policies`; that size-1 axis is then squeezed away,
# leaving shape (4, 2, depth).
pol_one_hot = tf.squeeze(
    tf.one_hot(policies, depth=tf.reduce_max(unique_pols + 1), axis=-1),
    axis=-2,
)
pol_one_hot

<tf.Tensor: shape=(4, 2, 3), dtype=float32, numpy=
array([[[0., 1., 0.],
        [0., 1., 0.]],

       [[0., 1., 0.],
        [1., 0., 0.]],

       [[1., 0., 0.],
        [0., 0., 1.]],

       [[1., 0., 0.],
        [0., 1., 0.]]], dtype=float32)>

In [13]:
# Batched (one-hot)^T @ obs_losses: for each leading entry, add up `obs_losses`
# separately for every policy id. The trailing size-1 axis is dropped -> (4, 3).
pol_mean_sum = tf.squeeze(
    tf.matmul(pol_one_hot, obs_losses, transpose_a=True),
    axis=-1,
)
pol_mean_sum

<tf.Tensor: shape=(4, 3), dtype=float32, numpy=
array([[ 0.,  5.,  0.],
       [10.,  7.,  0.],
       [13.,  0., 16.],
       [19., 22.,  0.]], dtype=float32)>

In [14]:
# How many times each policy id occurs per leading entry (column sums of the one-hot).
pol_count = tf.math.reduce_sum(pol_one_hot, axis=-2)
pol_count

<tf.Tensor: shape=(4, 3), dtype=float32, numpy=
array([[0., 2., 0.],
       [1., 1., 0.],
       [1., 0., 1.],
       [1., 1., 0.]], dtype=float32)>

In [15]:
# Per-policy mean = summed loss / occurrence count; divide_no_nan returns 0
# where the count is 0 instead of NaN.
policy_losses = tf.math.divide_no_nan(x=pol_mean_sum, y=pol_count)
policy_losses

<tf.Tensor: shape=(4, 3), dtype=float32, numpy=
array([[ 0. ,  2.5,  0. ],
       [10. ,  7. ,  0. ],
       [13. ,  0. , 16. ],
       [19. , 22. ,  0. ]], dtype=float32)>

In [16]:
# Add the per-policy mean losses together, one total per leading entry.
total_loss = tf.math.reduce_sum(policy_losses, axis=-1)
total_loss

<tf.Tensor: shape=(4,), dtype=float32, numpy=array([ 2.5, 17. , 29. , 41. ], dtype=float32)>

In [17]:
# Reference values: plain mean of each consecutive run of six values in 0..23.
tuple(np.mean(np.arange(start, start + 6.)) for start in (0., 6., 12., 18.))

(2.5, 8.5, 14.5, 20.5)

In [18]:
# Same totals, but summing only the entries of policies that actually occur.
tf.reduce_sum(
    tf.ragged.boolean_mask(policy_losses, mask=tf.greater(pol_count, 0.)),
    axis=-1,
)

<tf.Tensor: shape=(4,), dtype=float32, numpy=array([ 2.5, 17. , 29. , 41. ], dtype=float32)>

In [19]:
# Variance across the per-policy losses, ignoring policies with a zero count so
# their placeholder 0 does not enter the variance.
loss_var = tf.math.reduce_variance(
    tf.ragged.boolean_mask(policy_losses, mask=tf.greater(pol_count, 0.)),
    axis=-1,
)
loss_var

<tf.Tensor: shape=(4,), dtype=float32, numpy=array([0.  , 2.25, 2.25, 2.25], dtype=float32)>

In [36]:
# Show the per-policy occurrence counts again for reference.
pol_count

<tf.Tensor: shape=(4, 3), dtype=float32, numpy=
array([[0., 2., 0.],
       [1., 1., 0.],
       [1., 0., 1.],
       [1., 1., 0.]], dtype=float32)>

In [37]:
# Pair each row of per-policy losses with its row of counts -> (4, 2, 3).
tf.stack([policy_losses, pol_count], axis=-2)

<tf.Tensor: shape=(4, 2, 3), dtype=float32, numpy=
array([[[ 0. ,  2.5,  0. ],
        [ 0. ,  2. ,  0. ]],

       [[10. ,  7. ,  0. ],
        [ 1. ,  1. ,  0. ]],

       [[13. ,  0. , 16. ],
        [ 1. ,  0. ,  1. ]],

       [[19. , 22. ,  0. ],
        [ 1. ,  1. ,  0. ]]], dtype=float32)>

In [39]:
def determine_var(x):
    """Variance of the losses belonging to policies that actually occur.

    `x` is indexed as a two-row tensor: row 0 holds the per-policy losses and
    row 1 the matching occurrence counts. Entries whose count is not positive
    are dropped before the variance is taken.
    """
    present = tf.greater(x[1, :], 0.)
    kept_losses = tf.boolean_mask(x[0, :], present)
    return tf.math.reduce_variance(kept_losses)

In [40]:
# Apply `determine_var` to every stacked (losses, counts) pair.
tf.map_fn(fn=determine_var, elems=tf.stack([policy_losses, pol_count], axis=-2))

<tf.Tensor: shape=(4,), dtype=float32, numpy=array([0.  , 2.25, 2.25, 2.25], dtype=float32)>

In [21]:
# Summed per-policy loss plus the variance across policies.
# NOTE(review): `rex_beta` is defined in the Setup section but is not applied
# here — confirm whether the variance term should be scaled by it.
total_loss_var = tf.add(total_loss, loss_var)
total_loss_var

<tf.Tensor: shape=(4,), dtype=float32, numpy=array([ 2.5 , 19.25, 31.25, 43.25], dtype=float32)>

# Alternate

In [None]:
# Append the policy id as an extra trailing column of `losses`.
# `policies` was cast to int32 in the Solution section and tf.concat needs all
# inputs to share one dtype, so cast it to the dtype of `losses` first (a no-op
# if `policies` is still float).
tf.concat((losses, tf.cast(policies, losses.dtype)), axis=-1)

In [None]:
def process_member(x):
    """Mean loss per policy id for one slice of the concatenated tensor.

    Args:
        x: 2-D tensor whose last column holds the policy id of each row and
            whose remaining columns hold that row's loss values.

    Returns:
        1-D tensor with one entry per id in the notebook-level `unique_pols`:
        the mean of all loss values in the rows carrying that id, or 0 when no
        row carries it.
    """
    data, pol = x[:, :-1], x[:, -1:]
    row_policy = pol[:, 0]

    def process_pol(y):
        # Whole rows are selected, so the mask does not need tiling across columns.
        selected = tf.boolean_mask(data, tf.equal(row_policy, y))
        n_selected = tf.cast(tf.size(selected), data.dtype)
        # reduce_mean over an empty selection is NaN; return 0 instead, matching
        # the divide_no_nan convention used for `policy_losses`.
        return tf.math.divide_no_nan(tf.reduce_sum(selected), n_selected)

    # Iterate over the ids in the dtype of `x`: `unique_pols` may be int32 while
    # the id column of `x` is float, and map_fn's output dtype defaults to the
    # dtype of its elements.
    return tf.map_fn(process_pol, tf.cast(unique_pols, x.dtype))

In [None]:
# Run `process_member` on every leading slice of the losses-plus-policy tensor.
# `policies` is cast to the dtype of `losses` because tf.concat rejects inputs of
# mixed dtype (`policies` is int32 after the Solution section).
result = tf.map_fn(
    process_member,
    tf.concat((losses, tf.cast(policies, losses.dtype)), axis=-1),
)

In [None]:
# Display the output of mapping `process_member` over the concatenated tensor.
result

In [None]:
# Add up the per-policy values along the last axis of `result`.
tf.math.reduce_sum(result, axis=-1)

In [None]:
# tf.math.unsorted_segment_mean(losses, tf.cast(policies, tf.int64), 2)

# Original

In [None]:
# Keep the losses where the policy id equals the first unique id; zero elsewhere.
mean_pol = tf.where(
    tf.equal(tf.tile(policies, multiples=(1, 1, losses.shape[-1])), unique_pols[0]),
    losses,
    tf.zeros_like(losses),
)
mean_pol

In [None]:
# Mean over the last axis; the zeroed-out entries are included in the average.
tf.math.reduce_mean(mean_pol, axis=-1)

In [None]:
# Mean over the last axis, then again over the new last axis.
tf.math.reduce_mean(
    tf.math.reduce_mean(mean_pol, axis=-1),
    axis=-1,
)

In [None]:
# Boolean mask with the same shape as `losses`: True where the policy id equals
# the first unique id. The tile factor is taken from `losses` rather than
# hard-coded as 3, so the mask keeps matching if the last axis changes width.
mask = tf.tile(tf.equal(policies, unique_pols[0]), (1, 1, losses.shape[-1]))
mask

In [None]:
# Select the loss entries flagged by `mask`.
tf.boolean_mask(tensor=losses, mask=mask)

In [None]:
# Same mask construction, restricted to the first leading slice.
mask = tf.tile(
    tf.equal(policies[0], unique_pols[0]),
    multiples=(1, losses.shape[-1]),
)
mask

In [None]:
# First-slice mask again, displayed without assigning it.
tf.tile(tf.equal(policies[0], unique_pols[0]), multiples=(1, losses.shape[-1]))

In [None]:
# First leading slice of the losses-plus-policy tensor.
# `policies` is cast to the dtype of `losses` because tf.concat rejects inputs of
# mixed dtype (`policies` is int32 after the Solution section).
x = tf.concat((losses, tf.cast(policies, losses.dtype)), axis=-1)[0, :, :]
x

In [None]:
# Split `x` into its loss columns and its trailing policy-id column.
# Note that this rebinds `data`, which held the (4, 3) toy matrix earlier.
data = x[:, :-1]
pol = x[:, -1:]
data, pol

In [None]:
# Mean of the loss values in the rows whose policy id equals the first unique id.
# `pol` comes out of the float tensor `x`, whereas `unique_pols` may be int32, and
# tf.equal needs both operands in one dtype — hence the cast.
mask = tf.tile(
    tf.equal(pol, tf.cast(unique_pols[0], pol.dtype)),
    (1, data.shape[-1]),
)
tf.reduce_mean(tf.boolean_mask(data, mask))

In [None]:
# Losses with the policy id appended as an extra trailing column.
# `policies` is cast to the dtype of `losses` because tf.concat rejects inputs of
# mixed dtype (`policies` is int32 after the Solution section).
tf.concat((losses, tf.cast(policies, losses.dtype)), axis=-1)

In [None]:
# Select the length-3 loss vectors whose policy id equals the first unique id.
# tf.boolean_mask needs the mask shape to match the leading dimensions of the
# tensor; the raw comparison has shape (4, 2, 1), which is incompatible with the
# (4, 2, 3) `losses`. Dropping the trailing size-1 axis gives a (4, 2) mask that
# picks whole rows.
mask = tf.squeeze(tf.equal(policies, unique_pols[0]), axis=-1)
tf.boolean_mask(losses, mask)

In [None]:
# Ragged selection of the loss entries whose policy id equals the second unique id.
tf.ragged.boolean_mask(
    losses,
    tf.tile(tf.equal(policies, unique_pols[1]), multiples=(1, 1, 3)),
)

In [None]:
def simple_ret(x):
    """Return the last entry along axis 1 for every index of axis 0 of `x`."""
    last_column = x[:, -1]
    return last_column

In [None]:
# Map `simple_ret` over the concatenation of `mean` and `policies`.
# NOTE(review): `mean` is not defined in any cell visible in this notebook; this
# cell presumably relies on a variable from an earlier kernel session — confirm.
tf.map_fn(simple_ret, tf.concat((mean, policies), axis=-1))

In [None]:
# Stand-alone tf.gather_nd example: two 4x4 matrices of logits ...
logits = tf.constant([
    [[10.0, 10.0, 20.0, 20.0],
     [11.0, 10.0, 10.0, 30.0],
     [12.0, 10.0, 10.0, 20.0],
     [13.0, 10.0, 10.0, 20.0]],
    [[14.0, 11.0, 21.0, 31.0],
     [15.0, 11.0, 11.0, 21.0],
     [16.0, 11.0, 11.0, 21.0],
     [17.0, 11.0, 11.0, 21.0]],
])

# ... and index pairs into the first two axes of `logits`.
indices = tf.constant([
    [[0, 0], [0, 1]],
    [[1, 1], [1, 3]],
])

# Note that this rebinds `result` from the Alternate section.
result = tf.gather_nd(params=logits, indices=indices)
result

In [None]:
# Shapes of the gather_nd example inputs, side by side.
(logits.shape, indices.shape)

In [None]:
# NOTE(review): `mean` is not defined in any cell visible in this notebook — confirm where it comes from.
mean

In [None]:
# NOTE(review): `mean` is not defined in any cell visible in this notebook, and
# tf.slice is documented as taking (input_, begin, size) while no `size` is
# passed here — confirm the intent before running this cell.
tf.slice(mean, tf.cast(policies, tf.int64))