The purpose of this script was to work out how best to implement REx in TensorFlow.

Each mini-batch of data will contain an unknown number of records for each policy. Some policies may not be represented at all.

It was important to correctly determine the loss for each policy, and the variance across the policies.

In [1]:
import numpy as np
import tensorflow as tf
# tf.compat.v1.disable_eager_execution()
tf.compat.v1.enable_eager_execution()

In [2]:
sess = tf.compat.v1.Session()

# Extract Records based on the Policy they belong to

This would be applied before the data is split among batches.

In [3]:
# Creation of mock data
data = np.reshape(np.arange(12, dtype=float), (4,3))
data

array([[ 0.,  1.,  2.],
       [ 3.,  4.,  5.],
       [ 6.,  7.,  8.],
       [ 9., 10., 11.]])

In [4]:
# Assign every other record to policy 0, and the remainder to policy 1
policies = np.array([0.,1.,0.,1.])[:, None]
policies

array([[0.],
       [1.],
       [0.],
       [1.]])

Method 1 for extracting records based on their policy, using NumPy indexing.

In [5]:
data[np.squeeze(np.argwhere(np.squeeze(policies)==0.)), :]

array([[0., 1., 2.],
       [6., 7., 8.]])

In [6]:
data[np.squeeze(np.argwhere(np.squeeze(policies)==1.)), :]

array([[ 3.,  4.,  5.],
       [ 9., 10., 11.]])
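The same selection can be written with a boolean mask instead of `np.argwhere` plus `np.squeeze`. This is a minimal sketch rather than the approach used elsewhere in this notebook; the helper name `records_for_policy` is hypothetical. The mask keeps the result two-dimensional even when a policy matches only one record or none, whereas squeezing a single index collapses the result to one row.

```python
import numpy as np

# Same mock data and policy assignment as in the cells above
data = np.arange(12, dtype=float).reshape(4, 3)
policies = np.array([0., 1., 0., 1.])[:, None]

def records_for_policy(data, policies, policy_id):
    """Return the rows of `data` whose policy label equals `policy_id` (hypothetical helper)."""
    # Boolean mask over rows; the result is always 2-D, even for 0 or 1 matches
    mask = policies[:, 0] == policy_id
    return data[mask]

policy_0 = records_for_policy(data, policies, 0.)
policy_1 = records_for_policy(data, policies, 1.)
no_match = records_for_policy(data, policies, 2.)  # empty, shape (0, 3)
```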

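Method 2 for extracting records based on their policy, this time with TensorFlow ops. Both calls below fail as written. The likely cause is that `==` on a tensor is not an element-wise comparison in this TensorFlow version, so `tf.where` receives a scalar; `tf.equal` makes the comparison explicit.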

In [7]:
tf.gather(data, tf.squeeze(tf.where(tf.squeeze(policies)==0.)))

W0302 09:45:42.794969 47371650098048 deprecation.py:323] From <ipython-input-7-1ae502cbc724>:1: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


InvalidArgumentError: WhereOp : Unhandled input dimensions: 0 [Op:Where] name: Where/

In [8]:
tf.gather(data, tf.squeeze(tf.where(tf.squeeze(policies)==1.)))

InvalidArgumentError: WhereOp : Unhandled input dimensions: 0 [Op:Where] name: Where/

# Setup

In [9]:
rex_beta = 0.1

The cell below simulates the MSEs for an ensemble of 4 models (`M`), each passed 3 observations (`N`) with a dimensionality of 2 (`D`). The tensor thus has shape `MxNxD`, or `4x3x2` in this case.

This is what would have been produced in the original code for the MSE - an error for each dimension in each record for each model.

Remember that each model will receive a different mini-batch of data, and so the policy of a record at a given index can/will vary across the models.

In [10]:
losses_arr = np.reshape(np.arange(24, dtype=float), (4,3,2))
losses = tf.constant(losses_arr, dtype=float)
losses#.eval(session=sess)

<tf.Tensor: id=7, shape=(4, 3, 2), dtype=float32, numpy=
array([[[ 0.,  1.],
        [ 2.,  3.],
        [ 4.,  5.]],

       [[ 6.,  7.],
        [ 8.,  9.],
        [10., 11.]],

       [[12., 13.],
        [14., 15.],
        [16., 17.]],

       [[18., 19.],
        [20., 21.],
        [22., 23.]]], dtype=float32)>

Assign a policy to each record fed to each model. There are 3 policies (`P`), and so while each model could receive an observation from every policy, we've intentionally ensured this is not the case, to cover the situations where this arises in reality.

The policies are stored in an `MxNx1` matrix - which would be extracted from the original data passed to the model.

In [11]:
policies = tf.constant([
    [
        [1.],
        [1.],
        [2.],
    ],
    [
        [1.],
        [0.],
        [0.],
    ],
    [
        [0.],
        [2.],
        [2.],
    ],
    [
        [0.],
        [1.],
        [1.],
    ]

])
policies#.eval(session=sess)

<tf.Tensor: id=9, shape=(4, 3, 1), dtype=float32, numpy=
array([[[1.],
        [1.],
        [2.]],

       [[1.],
        [0.],
        [0.]],

       [[0.],
        [2.],
        [2.]],

       [[0.],
        [1.],
        [1.]]], dtype=float32)>

Take the mean across the dimensions of each record - this is what the MOPO code does, rather than taking the vector norm.

In the below form, the observation losses could alternatively be the log-likelihood.

In [12]:
obs_losses = tf.reduce_mean(losses, axis=-1, keepdims=True)
obs_losses#.eval(session=sess)

<tf.Tensor: id=12, shape=(4, 3, 1), dtype=float32, numpy=
array([[[ 0.5],
        [ 2.5],
        [ 4.5]],

       [[ 6.5],
        [ 8.5],
        [10.5]],

       [[12.5],
        [14.5],
        [16.5]],

       [[18.5],
        [20.5],
        [22.5]]], dtype=float32)>

# Solution

This is the solution implemented in `mopo/models/bnn.py`. It was tested under a number of different scenarios to capture edge cases.

Integers are used to identify policies - make sure that the data type is always correct by explicitly casting.

In [13]:
policies = tf.cast(policies, tf.int32)
policies#.eval(session=sess)

<tf.Tensor: id=14, shape=(4, 3, 1), dtype=int32, numpy=
array([[[1],
        [1],
        [2]],

       [[1],
        [0],
        [0]],

       [[0],
        [2],
        [2]],

       [[0],
        [1],
        [1]]], dtype=int32)>

In [14]:
unique_pols = tf.unique(tf.reshape(policies, [-1])).y
unique_pols#.eval(session=sess)

<tf.Tensor: id=18, shape=(3,), dtype=int32, numpy=array([1, 2, 0], dtype=int32)>

Identify the policy with the largest integer in the mini-batch. Note that there is no requirement that records be present for all policies.

For instance, we may have records for policies [0, 1, 4] in the current mini-batch. The highest policy integer is therefore 4, and it does not matter that we do not have records for policy 3.

Similarly, it may be that when looking at the entire dataset the highest policy integer is actually 5 - it does not matter if a mini-batch has no records for this policy.

In [15]:
tf.reduce_max(unique_pols+1)

<tf.Tensor: id=24, shape=(), dtype=int32, numpy=3>

Create a one-hot encoded matrix which identifies the policy each record in the minibatch belongs to. This has dimension `MxNxP`.

In [16]:
pol_one_hot = tf.squeeze(tf.one_hot(policies, tf.reduce_max(unique_pols+1), axis=-1), axis=-2)
pol_one_hot#.eval(session=sess)

<tf.Tensor: id=33, shape=(4, 3, 3), dtype=float32, numpy=
array([[[0., 1., 0.],
        [0., 1., 0.],
        [0., 0., 1.]],

       [[0., 1., 0.],
        [1., 0., 0.],
        [1., 0., 0.]],

       [[1., 0., 0.],
        [0., 0., 1.],
        [0., 0., 1.]],

       [[1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.]]], dtype=float32)>
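As a cross-check on the one-hot step, the same `MxNxP` matrix can be built in NumPy by indexing an identity matrix with the policy ids. This is only a sketch for verifying the TensorFlow output; it assumes the policies are already integers and have had their trailing singleton axis removed.

```python
import numpy as np

# Policy ids per model (rows) and record (columns), i.e. an M x N integer array
policies = np.array([[1, 1, 2],
                     [1, 0, 0],
                     [0, 2, 2],
                     [0, 1, 1]])

# Highest policy id in the mini-batch plus one, as in the cell above
num_policies = int(policies.max()) + 1

# Row i of the identity matrix is the one-hot vector for policy i
pol_one_hot = np.eye(num_policies, dtype=np.float32)[policies]  # M x N x P
```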

Use the one-hot matrix to sum the losses for each policy - this has dimensions `MxP`.

In [17]:
pol_mean_sum = tf.squeeze(tf.matmul(tf.transpose(pol_one_hot, [0,2,1]), obs_losses), axis=-1)
pol_mean_sum#.eval(session=sess)

<tf.Tensor: id=38, shape=(4, 3), dtype=float32, numpy=
array([[ 0. ,  3. ,  4.5],
       [19. ,  6.5,  0. ],
       [12.5,  0. , 31. ],
       [18.5, 43. ,  0. ]], dtype=float32)>

Identify the number of records present for each policy. Remember that we'd intentionally designed the dataset so that each model received a mini-batch with no records for one policy.

This has dimensions `MxP`.

In [18]:
pol_count = tf.reduce_sum(pol_one_hot, axis=-2)
pol_count#.eval(session=sess)

<tf.Tensor: id=41, shape=(4, 3), dtype=float32, numpy=
array([[0., 2., 1.],
       [2., 1., 0.],
       [1., 0., 2.],
       [1., 2., 0.]], dtype=float32)>

Determine the mean loss for each policy. Use `tf.math.divide_no_nan` so that a policy with no records yields 0 rather than NaN from a 0/0 division (each model has no records for one policy).

This again has dimensions `MxP`.

In [19]:
policy_losses = tf.math.divide_no_nan(pol_mean_sum, pol_count)
policy_losses#.eval(session=sess)

<tf.Tensor: id=43, shape=(4, 3), dtype=float32, numpy=
array([[ 0. ,  1.5,  4.5],
       [ 9.5,  6.5,  0. ],
       [12.5,  0. , 15.5],
       [18.5, 21.5,  0. ]], dtype=float32)>
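The sum, count and safe-divide steps above can be reproduced in NumPy to confirm the numbers. This is a verification sketch under the same mock data, not the code in `mopo/models/bnn.py`; `np.einsum` stands in for the transposed `tf.matmul`, and `np.divide` with a `where` mask stands in for `tf.math.divide_no_nan`.

```python
import numpy as np

# Per-record losses (mean over the D axis), shape M x N
obs_losses = np.arange(24, dtype=np.float32).reshape(4, 3, 2).mean(axis=-1)

policies = np.array([[1, 1, 2], [1, 0, 0], [0, 2, 2], [0, 1, 1]])
pol_one_hot = np.eye(3, dtype=np.float32)[policies]  # M x N x P

# Sum of losses per (model, policy): contract over the record axis
pol_sum = np.einsum('mnp,mn->mp', pol_one_hot, obs_losses)

# Number of records per (model, policy)
pol_count = pol_one_hot.sum(axis=1)

# Mean loss per (model, policy); entries with no records stay at 0
policy_losses = np.divide(pol_sum, pol_count,
                          out=np.zeros_like(pol_sum),
                          where=pol_count > 0)
```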

Determine the mean loss for each policy across the models - resulting in a vector of length `P`. Note that a model with no records for a policy contributes a 0 to that policy's mean.

NOTE: This was calculated solely for information purposes, to track how the loss for each policy changed during training.

In [20]:
mean_policy_losses = tf.reduce_mean(policy_losses, axis=0)
mean_policy_losses#.eval(session=sess)

<tf.Tensor: id=46, shape=(3,), dtype=float32, numpy=array([10.125,  7.375,  5.   ], dtype=float32)>
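Because absent policies are stored as 0, the plain mean over models is pulled towards zero. If that matters for tracking, one alternative is to average only over the models that actually saw the policy. The sketch below compares the two in NumPy; the "present only" variant is an illustration and is not what the cell above computes.

```python
import numpy as np

policy_losses = np.array([[0., 1.5, 4.5],
                          [9.5, 6.5, 0.],
                          [12.5, 0., 15.5],
                          [18.5, 21.5, 0.]])
pol_count = np.array([[0., 2., 1.],
                      [2., 1., 0.],
                      [1., 0., 2.],
                      [1., 2., 0.]])

# Mean over all models, zeros included (matches the cell above)
plain_mean = policy_losses.mean(axis=0)

# Mean over only the models that have records for the policy.
# np.maximum guards against a policy that no model saw at all.
models_with_policy = (pol_count > 0).sum(axis=0)
present_mean = policy_losses.sum(axis=0) / np.maximum(models_with_policy, 1)
```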

Sum the policy losses for each model - resulting in a vector of length `M`.

In [21]:
policy_total_losses = tf.reduce_sum(policy_losses, axis=-1)
policy_total_losses#.eval(session=sess)

<tf.Tensor: id=49, shape=(4,), dtype=float32, numpy=array([ 6., 16., 28., 40.], dtype=float32)>

Manually determine what the variances should be.

In [22]:
np.var(np.array((1.5,4.5))), np.var(np.array((9.5,6.5))), np.var(np.array((12.5,15.5))), np.var(np.array((18.5,21.5)))

(2.25, 2.25, 2.25, 2.25)

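Two ways of computing the variance are shown below, and they do not give the same result. Applying `tf.boolean_mask` to the whole `MxP` matrix flattens it, so the first attempt returns a single variance over every model/policy entry that is present (42.9375). Mapping over the models with `tf.map_fn` keeps them separate and gives the per-model variances expected above.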

In [33]:
policy_losses

<tf.Tensor: id=43, shape=(4, 3), dtype=float32, numpy=
array([[ 0. ,  1.5,  4.5],
       [ 9.5,  6.5,  0. ],
       [12.5,  0. , 15.5],
       [18.5, 21.5,  0. ]], dtype=float32)>

In [36]:
tf.math.reduce_variance(tf.boolean_mask(policy_losses, pol_count>0.), axis=-1)#.eval(session=sess)

<tf.Tensor: id=587, shape=(), dtype=float32, numpy=42.9375>

In [24]:
def determine_var(x):
    batch_pol_losses, batch_pol_counts = x[0,:], x[1,:]
    return tf.math.reduce_variance(tf.boolean_mask(batch_pol_losses, batch_pol_counts>0.))

In [25]:
policy_var_losses = tf.map_fn(determine_var, tf.stack((policy_losses, pol_count), axis=-2))
policy_var_losses#.eval(session=sess)

<tf.Tensor: id=311, shape=(4,), dtype=float32, numpy=array([2.25, 2.25, 2.25, 2.25], dtype=float32)>
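The per-model variance can also be computed without mapping over the models, by weighting with the presence mask directly. The NumPy sketch below shows that vectorised form next to the flattened variant, to make the difference between the two results explicit. It assumes every model has at least one policy present; otherwise the division by the number of present policies would need guarding.

```python
import numpy as np

policy_losses = np.array([[0., 1.5, 4.5],
                          [9.5, 6.5, 0.],
                          [12.5, 0., 15.5],
                          [18.5, 21.5, 0.]])
pol_count = np.array([[0., 2., 1.],
                      [2., 1., 0.],
                      [1., 0., 2.],
                      [1., 2., 0.]])

# 1.0 where the model has records for the policy, 0.0 otherwise
mask = (pol_count > 0).astype(float)
n_present = mask.sum(axis=-1, keepdims=True)  # assumed > 0 for every model

# Masked mean and (population) variance along the policy axis
row_mean = (policy_losses * mask).sum(axis=-1, keepdims=True) / n_present
per_model_var = (mask * (policy_losses - row_mean) ** 2).sum(axis=-1) / n_present[:, 0]

# Boolean indexing on the full matrix flattens it, giving one global variance
global_var = policy_losses[pol_count > 0].var()
```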

In [26]:
total_loss_var = policy_total_losses + policy_var_losses
total_loss_var#.eval(session=sess)

<tf.Tensor: id=313, shape=(4,), dtype=float32, numpy=array([ 8.25, 18.25, 30.25, 42.25], dtype=float32)>
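The cell above adds the variance with an implicit weight of 1, and `rex_beta` from the setup is not used. Assuming the intent is the usual variance-penalised form, in which the variance term is scaled by a coefficient, the weighted total could be sketched as follows. The function name `rex_objective` is hypothetical, and the weighting itself is an assumption about how `rex_beta` is meant to be applied.

```python
import numpy as np

rex_beta = 0.1
policy_total_losses = np.array([6., 16., 28., 40.])
policy_var_losses = np.array([2.25, 2.25, 2.25, 2.25])

def rex_objective(total_losses, var_losses, beta):
    """Per-model total loss plus beta-weighted variance across policies (assumed form)."""
    return total_losses + beta * var_losses

weighted_total = rex_objective(policy_total_losses, policy_var_losses, rex_beta)

# With beta = 1 this reduces to the unweighted sum computed in the cell above
unweighted_total = rex_objective(policy_total_losses, policy_var_losses, 1.0)
```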

In [46]:
def determine_abs_var(x):
    batch_pol_losses, batch_pol_counts = x[0,:], x[1,:]
    print('batch_pol_losses', batch_pol_losses)
    mean = tf.math.reduce_mean(tf.boolean_mask(batch_pol_losses, batch_pol_counts>0.))
    batch_pol_losses_var_abs = tf.math.abs(batch_pol_losses - mean)
    return tf.math.reduce_mean(tf.boolean_mask(batch_pol_losses_var_abs, batch_pol_counts>0.))

In [28]:
policy_losses

<tf.Tensor: id=43, shape=(4, 3), dtype=float32, numpy=
array([[ 0. ,  1.5,  4.5],
       [ 9.5,  6.5,  0. ],
       [12.5,  0. , 15.5],
       [18.5, 21.5,  0. ]], dtype=float32)>

In [40]:
mean = tf.math.reduce_mean(tf.boolean_mask(policy_losses, pol_count>0.))
mean

<tf.Tensor: id=712, shape=(), dtype=float32, numpy=11.25>

In [42]:
policy_losses_var_abs = tf.math.abs(policy_losses - mean)
policy_losses_var_abs

<tf.Tensor: id=717, shape=(4, 3), dtype=float32, numpy=
array([[11.25,  9.75,  6.75],
       [ 1.75,  4.75, 11.25],
       [ 1.25, 11.25,  4.25],
       [ 7.25, 10.25, 11.25]], dtype=float32)>

In [47]:
policy_var_losses = tf.map_fn(determine_abs_var, tf.stack((policy_losses, pol_count), axis=-2))
policy_var_losses#.eval(session=sess)

batch_pol_losses tf.Tensor([0.  1.5 4.5], shape=(3,), dtype=float32)
batch_pol_losses tf.Tensor([9.5 6.5 0. ], shape=(3,), dtype=float32)
batch_pol_losses tf.Tensor([12.5  0.  15.5], shape=(3,), dtype=float32)
batch_pol_losses tf.Tensor([18.5 21.5  0. ], shape=(3,), dtype=float32)


<tf.Tensor: id=1405, shape=(4,), dtype=float32, numpy=array([1.5, 1.5, 1.5, 1.5], dtype=float32)>
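`determine_abs_var` replaces the variance with a mean absolute deviation over the policies present for each model. The NumPy sketch below reproduces the result as a cross-check; the helper name `masked_mean_abs_dev` is hypothetical.

```python
import numpy as np

policy_losses = np.array([[0., 1.5, 4.5],
                          [9.5, 6.5, 0.],
                          [12.5, 0., 15.5],
                          [18.5, 21.5, 0.]])
pol_count = np.array([[0., 2., 1.],
                      [2., 1., 0.],
                      [1., 0., 2.],
                      [1., 2., 0.]])

def masked_mean_abs_dev(losses, counts):
    """Mean absolute deviation over the policies that have records (hypothetical helper)."""
    present = losses[counts > 0]
    return np.abs(present - present.mean()).mean()

per_model_mad = np.array([masked_mean_abs_dev(losses, counts)
                          for losses, counts in zip(policy_losses, pol_count)])
```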

# IRM

This is not a complete method.

In [27]:
mean_arr = np.reshape(np.arange(24, dtype=float), (4,3,2))
# Build the tensor from mean_arr (same values as losses_arr, so the output is unchanged)
mean = tf.constant(mean_arr, dtype=float)
log_var = tf.identity(mean)
mean, log_var

(<tf.Tensor: shape=(4, 3, 2), dtype=float32, numpy=
 array([[[ 0.,  1.],
         [ 2.,  3.],
         [ 4.,  5.]],
 
        [[ 6.,  7.],
         [ 8.,  9.],
         [10., 11.]],
 
        [[12., 13.],
         [14., 15.],
         [16., 17.]],
 
        [[18., 19.],
         [20., 21.],
         [22., 23.]]], dtype=float32)>,
 <tf.Tensor: shape=(4, 3, 2), dtype=float32, numpy=
 array([[[ 0.,  1.],
         [ 2.,  3.],
         [ 4.,  5.]],
 
        [[ 6.,  7.],
         [ 8.,  9.],
         [10., 11.]],
 
        [[12., 13.],
         [14., 15.],
         [16., 17.]],
 
        [[18., 19.],
         [20., 21.],
         [22., 23.]]], dtype=float32)>)

In [28]:
policies = tf.constant([
    [
        [1.],
        [1.],
        [2.],
    ],
    [
        [1.],
        [0.],
        [0.],
    ],
    [
        [0.],
        [2.],
        [2.],
    ],
    [
        [0.],
        [1.],
        [1.],
    ]

])
policies#.eval(session=sess)

<tf.Tensor: shape=(4, 3, 1), dtype=float32, numpy=
array([[[1.],
        [1.],
        [2.]],

       [[1.],
        [0.],
        [0.]],

       [[0.],
        [2.],
        [2.]],

       [[0.],
        [1.],
        [1.]]], dtype=float32)>

In [29]:
policies = tf.cast(policies, tf.int32)
policies#.eval(session=sess)

<tf.Tensor: shape=(4, 3, 1), dtype=int32, numpy=
array([[[1],
        [1],
        [2]],

       [[1],
        [0],
        [0]],

       [[0],
        [2],
        [2]],

       [[0],
        [1],
        [1]]], dtype=int32)>

In [30]:
unique_pols = tf.unique(tf.reshape(policies, [-1])).y
unique_pols#.eval(session=sess)

<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 0], dtype=int32)>

In [31]:
pol_one_hot = tf.squeeze(tf.one_hot(policies, tf.reduce_max(unique_pols+1), axis=-1), axis=-2)
pol_one_hot#.eval(session=sess)

<tf.Tensor: shape=(4, 3, 3), dtype=float32, numpy=
array([[[0., 1., 0.],
        [0., 1., 0.],
        [0., 0., 1.]],

       [[0., 1., 0.],
        [1., 0., 0.],
        [1., 0., 0.]],

       [[1., 0., 0.],
        [0., 0., 1.],
        [0., 0., 1.]],

       [[1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.]]], dtype=float32)>

In [32]:
pol_count = tf.reduce_sum(pol_one_hot, axis=-2)
pol_count#.eval(session=sess)

<tf.Tensor: shape=(4, 3), dtype=float32, numpy=
array([[0., 2., 1.],
       [2., 1., 0.],
       [1., 0., 2.],
       [1., 2., 0.]], dtype=float32)>

## Simple IRM Example

In [33]:
test_input =  tf.constant(np.reshape(np.arange(4, dtype=float), (1,4)))
# test_w = tf.Variable(np.ones((1)), trainable=True)
test_w = tf.Variable(np.ones_like(test_input), trainable=True)
test_input, test_w, test_input*test_w

(<tf.Tensor: shape=(1, 4), dtype=float64, numpy=array([[0., 1., 2., 3.]])>,
 <tf.Variable 'Variable:0' shape=(1, 4) dtype=float64, numpy=array([[1., 1., 1., 1.]])>,
 <tf.Tensor: shape=(1, 4), dtype=float64, numpy=array([[0., 1., 2., 3.]])>)

In [34]:
with tf.GradientTape(persistent=True) as tape:
    result = (test_input*test_w)**2
tape.gradient(result, test_w)

<tf.Tensor: shape=(1, 4), dtype=float64, numpy=array([[ 0.,  2.,  8., 18.]])>
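The gradient above can be checked by hand: for `result = (x*w)**2` the derivative with respect to `w` is `2 * x**2 * w`, which at `w = 1` gives `[0, 2, 8, 18]`. The sketch below confirms this in NumPy with a central finite difference; the step size `eps` is an arbitrary choice.

```python
import numpy as np

x = np.arange(4, dtype=float).reshape(1, 4)
w = np.ones_like(x)

# Analytic gradient of (x * w)**2 with respect to w
analytic_grad = 2.0 * x**2 * w

# Central finite difference; each output element depends only on its own w,
# so perturbing all elements at once is valid here
eps = 1e-6
numeric_grad = ((x * (w + eps)) ** 2 - (x * (w - eps)) ** 2) / (2 * eps)
```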

## Simple Example - Same Form as Solution

In [35]:
mean_dummy_w = tf.Variable(np.ones_like(mean), trainable=True, dtype=float)
log_var_dummy_w = tf.Variable(np.ones_like(log_var), trainable=True, dtype=float)

In [36]:
with tf.GradientTape(persistent=True) as tape_mean:
    with tf.GradientTape(persistent=True) as tape_log_var:
        mean_w = mean * mean_dummy_w
        obs_losses = tf.reduce_sum(mean_w**2, axis=-1, keepdims=True)

mean_dummy_grads = tape_mean.gradient(obs_losses, mean_dummy_w)
# log_var_dummy_w is not used in the loss above, so this gradient is None
log_var_dummy_grads =  tape_log_var.gradient(obs_losses, log_var_dummy_w)
mean_dummy_grads#, log_var_dummy_grads

<tf.Tensor: shape=(4, 3, 2), dtype=float32, numpy=
array([[[   0.,    2.],
        [   8.,   18.],
        [  32.,   50.]],

       [[  72.,   98.],
        [ 128.,  162.],
        [ 200.,  242.]],

       [[ 288.,  338.],
        [ 392.,  450.],
        [ 512.,  578.]],

       [[ 648.,  722.],
        [ 800.,  882.],
        [ 968., 1058.]]], dtype=float32)>

In [37]:
obs_losses

<tf.Tensor: shape=(4, 3, 1), dtype=float32, numpy=
array([[[1.000e+00],
        [1.300e+01],
        [4.100e+01]],

       [[8.500e+01],
        [1.450e+02],
        [2.210e+02]],

       [[3.130e+02],
        [4.210e+02],
        [5.450e+02]],

       [[6.850e+02],
        [8.410e+02],
        [1.013e+03]]], dtype=float32)>

In [38]:
pol_mean_dummy_grad_sum = tf.matmul(tf.transpose(pol_one_hot, [0,2,1]), mean_dummy_grads)
pol_mean_dummy_grad_sum

<tf.Tensor: shape=(4, 3, 2), dtype=float32, numpy=
array([[[   0.,    0.],
        [   8.,   20.],
        [  32.,   50.]],

       [[ 328.,  404.],
        [  72.,   98.],
        [   0.,    0.]],

       [[ 288.,  338.],
        [   0.,    0.],
        [ 904., 1028.]],

       [[ 648.,  722.],
        [1768., 1940.],
        [   0.,    0.]]], dtype=float32)>

In [39]:
tf.reduce_sum(tf.reduce_sum((pol_mean_dummy_grad_sum**2), axis=-1), axis=-1)

<tf.Tensor: shape=(4,), dtype=float32, numpy=
array([3.988000e+03, 2.855880e+05, 2.071188e+06, 7.830612e+06],
      dtype=float32)>
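The per-model values above can be reproduced without a gradient tape, since the gradient of `sum((mean*w)**2)` with respect to `w` at `w = 1` is simply `2 * mean**2`. The NumPy sketch below sums those gradients per policy and then sums their squares per model. It is a cross-check of the arithmetic only and makes no claim about the final form of the penalty.

```python
import numpy as np

mean = np.arange(24, dtype=float).reshape(4, 3, 2)  # M x N x D
policies = np.array([[1, 1, 2], [1, 0, 0], [0, 2, 2], [0, 1, 1]])
pol_one_hot = np.eye(3)[policies]  # M x N x P

# Gradient of sum((mean * w)**2) w.r.t. the dummy weights, evaluated at w = 1
grads = 2.0 * mean**2

# Sum the gradients of the records belonging to each policy: M x P x D
pol_grad_sum = np.einsum('mnp,mnd->mpd', pol_one_hot, grads)

# Squared gradient sums, accumulated over policies and dimensions: length M
penalty = (pol_grad_sum**2).sum(axis=(-1, -2))
```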

In [40]:
with tf.GradientTape() as tape_mean:
    with tf.GradientTape() as tape_log_var:
        mean_w = mean * mean_dummy_w
        log_var_w = log_var * log_var_dummy_w
        inv_var_w = tf.exp(-log_var_w)
        obs_mse_losses = tf.reduce_mean(tf.square(mean_w - (mean+0.1)) * inv_var_w, axis=-1, keepdims=True)
        obs_var_losses = tf.reduce_mean(log_var_w, axis=-1, keepdims=True)
        obs_losses = obs_mse_losses + obs_var_losses
        pol_mean_sum = tf.squeeze(tf.matmul(tf.transpose(pol_one_hot, [0,2,1]), obs_losses), axis=-1)
        policy_losses = tf.math.divide_no_nan(pol_mean_sum, pol_count)

mean_dummy_grads = tape_mean.gradient(policy_losses, mean_dummy_w)
log_var_dummy_grads = tape_log_var.gradient(policy_losses, log_var_dummy_w)

In [41]:
obs_losses

<tf.Tensor: shape=(4, 3, 1), dtype=float32, numpy=
array([[[ 0.5068394],
        [ 2.5009255],
        [ 4.5001254]],

       [[ 6.500017 ],
        [ 8.500002 ],
        [10.5      ]],

       [[12.5      ],
        [14.5      ],
        [16.5      ]],

       [[18.5      ],
        [20.5      ],
        [22.5      ]]], dtype=float32)>

In [42]:
mean_dummy_grads

<tf.Tensor: shape=(4, 3, 2), dtype=float32, numpy=
array([[[-0.0000000e+00, -1.8393977e-02],
        [-1.3533515e-02, -7.4680536e-03],
        [-7.3262486e-03, -3.3689705e-03]],

       [[-1.4872500e-03, -6.3831673e-04],
        [-1.3418557e-04, -5.5534623e-05],
        [-2.2700051e-05, -9.1859702e-06]],

       [[-7.3730830e-06, -2.9384394e-06],
        [-5.8207235e-07, -2.2942763e-07],
        [-9.0028486e-08, -3.5189608e-08]],

       [[-2.7414071e-08, -1.0645354e-08],
        [-2.0611615e-09, -7.9617191e-10],
        [-3.0684266e-10, -1.1801207e-10]]], dtype=float32)>

In [43]:
tf.reduce_sum(tf.reduce_sum(tf.matmul(tf.transpose(pol_one_hot, [0,2,1]), mean_dummy_grads), axis=-1), axis=-1)

<tf.Tensor: shape=(4,), dtype=float32, numpy=
array([-5.0090764e-02, -2.3471732e-03, -1.1248240e-05, -4.1341615e-08],
      dtype=float32)>

In [44]:
tf.reduce_sum(tf.reduce_sum(tf.matmul(tf.transpose(pol_one_hot, [0,2,1]), log_var_dummy_grads), axis=-1), axis=-1)

<tf.Tensor: shape=(4,), dtype=float32, numpy=array([ 5.9974957, 15.999882 , 28.       , 40.       ], dtype=float32)>

In [45]:
tf.squeeze(obs_losses)#.eval(session=sess)

<tf.Tensor: shape=(4, 3), dtype=float32, numpy=
array([[ 0.5068394,  2.5009255,  4.5001254],
       [ 6.500017 ,  8.500002 , 10.5      ],
       [12.5      , 14.5      , 16.5      ],
       [18.5      , 20.5      , 22.5      ]], dtype=float32)>

In [46]:
# tf.map_fn(lambda x: tf.gradients(x, mean_dummy_w)[0], tf.squeeze(obs_losses, axis=-1)).eval(session=sess)

In [47]:
# tf.gradients(obs_losses, mean_dummy_w)[0].eval(session=sess), tf.gradients(obs_losses, log_var_dummy_w)[0].eval(session=sess)