In [1]:
# Reload edited project modules (e.g. build_model) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import sys

# Make project modules one directory up importable (build_model, neural_interaction_detection, ...).
# '..' is resolved against the kernel's working directory (usually the notebook's folder).
# Guarded so re-running this cell does not keep appending duplicate entries.
if '..' not in sys.path:
    sys.path.append('..')

In [3]:
import tensorflow as tf
import numpy as np
from build_model import interaction_model, SelectInputs

In [4]:
from path_explain import utils
# NOTE(review): `visible_devices='1'` presumably restricts TensorFlow to device index 1
# (e.g. via CUDA_VISIBLE_DEVICES) — confirm against path_explain.utils.set_up_environment.
utils.set_up_environment(visible_devices='1')

In [5]:
# Configuration for the toy interaction model explored in this notebook:
# 5 input features, 2 hidden layers of 8 units, a single output, ReLU activations,
# and no interactions excluded.
model_kwargs = dict(
    num_features=5,
    num_layers=2,
    hidden_layer_size=8,
    num_outputs=1,
    activation_function=tf.keras.activations.relu,
    interactions_to_ignore=None,
)
model = interaction_model(**model_kwargs)

In [6]:
# Probe the serialized activation of one layer; layers whose config has no
# 'activation' key display None. NOTE(review): index 20 is hard-coded.
layer_20_config = model.layers[20].get_config()
layer_20_config.get('activation')

'relu'

In [7]:
# Print every layer whose serialized config reports a 'relu' activation.
relu_layers = [lyr for lyr in model.layers
               if lyr.get_config().get('activation') == 'relu']
for relu_layer in relu_layers:
    print(relu_layer)

<tensorflow.python.keras.layers.core.Dense object at 0x7efd45b5fbd0>
<tensorflow.python.keras.layers.core.Dense object at 0x7efd3004e610>
<tensorflow.python.keras.layers.core.Dense object at 0x7efd3003f350>
<tensorflow.python.keras.layers.core.Dense object at 0x7efcec0b5a90>
<tensorflow.python.keras.layers.core.Dense object at 0x7efcec04d110>
<tensorflow.python.keras.layers.core.Dense object at 0x7efcec06c750>
<tensorflow.python.keras.layers.core.Dense object at 0x7efcd45c1590>
<tensorflow.python.keras.layers.core.Dense object at 0x7efcd45dc110>
<tensorflow.python.keras.layers.core.Dense object at 0x7efcd45f6510>
<tensorflow.python.keras.layers.core.Dense object at 0x7efcd4592dd0>
<tensorflow.python.keras.layers.core.Dense object at 0x7efcd45ab110>
<tensorflow.python.keras.layers.core.Dense object at 0x7efcd4546e10>
<tensorflow.python.keras.layers.core.Dense object at 0x7efcd4560950>
<tensorflow.python.keras.layers.core.Dense object at 0x7efcd44fca10>
<tensorflow.python.keras.layers.co

In [8]:
# Layer-by-layer overview of the full interaction model (names, output shapes, params).
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input (InputLayer)              [(None, 5)]          0                                            
__________________________________________________________________________________________________
select_0 (SelectInputs)         (None, 1)            5           input[0][0]                      
__________________________________________________________________________________________________
select_1 (SelectInputs)         (None, 1)            5           input[0][0]                      
__________________________________________________________________________________________________
select_2 (SelectInputs)         (None, 1)            5           input[0][0]                      
______________________________________________________________________________________________

In [9]:
# Extract one pairwise subnetwork of `model` as a standalone Keras model whose
# output is that subnetwork's output scaled by its weight in `output_final`.
SUBNETWORK_ID = '0_1'  # suffix shared by this subnetwork's layer names (select_0_1 ... output_0_1)
select_layer_name = f'select_{SUBNETWORK_ID}'
subnetwork_output_name = f'output_{SUBNETWORK_ID}'

input_tensor = model.get_layer(select_layer_name).input

# Temporary model used only to enumerate the layers on the input -> subnetwork-output path.
path_model = tf.keras.models.Model(inputs=input_tensor,
                                   outputs=model.get_layer(subnetwork_output_name).output)
output_tensor = input_tensor
# Re-apply those same layer objects (and therefore the same weights) to build a linear chain;
# layers[0] is the InputLayer, so it is skipped.
for layer in path_model.layers[1:]:
    output_tensor = layer(output_tensor)

# Position of this subnetwork's output among the tensors fed to `concat`; that position
# selects the matching row of the `output_final` kernel.
concat_input_names = [tensor.name.split('/')[0] for tensor in model.get_layer('concat').input]
weight_multiply_index = concat_input_names.index(subnetwork_output_name)
# NOTE(review): assumes weights[0] of `output_final` is its Dense kernel, shaped (num_concat_inputs, units).
final_weighting = model.get_layer('output_final').weights[0][weight_multiply_index, :]
# Kept 3-D, shape (1, 1, units), so later elementwise comparisons broadcast as before.
final_weighting = tf.expand_dims(final_weighting, axis=0)
final_weighting = tf.expand_dims(final_weighting, axis=0)

# Frozen, bias-free Dense that applies the final weighting. `weights` expects a list of
# arrays; final_weighting[0] is the (1, units) kernel. Any bias of `output_final` is
# deliberately not included in the subnetwork's output.
output_tensor = tf.keras.layers.Dense(units=int(final_weighting.shape[-1]),
                                      activation=None,
                                      use_bias=False,
                                      weights=[final_weighting[0].numpy()],
                                      trainable=False,
                                      name='subnetwork_output_final')(output_tensor)
subnetwork = tf.keras.models.Model(inputs=input_tensor,
                                   outputs=output_tensor)

In [10]:
# Inspect the tensor feeding the `select_0_1` layer (displayed below as the model's 'input' tensor).
model.get_layer('select_0_1').input

<tf.Tensor 'input:0' shape=(None, 5) dtype=float32>

In [11]:
# NOTE(review): exploratory look at the `Input` signature; the result is not used
# anywhere below — consider removing before sharing.
tf.keras.layers.Input

<function tensorflow.python.keras.engine.input_layer.Input(shape=None, batch_size=None, name=None, dtype=None, sparse=False, tensor=None, ragged=False, **kwargs)>

In [12]:
# Confirm the extracted subnetwork's layer chain and which parameters are frozen.
subnetwork.summary()

Model: "model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input (InputLayer)           [(None, 5)]               0         
_________________________________________________________________
select_0_1 (SelectInputs)    (None, 2)                 10        
_________________________________________________________________
dense_0_0_1 (Dense)          (None, 8)                 24        
_________________________________________________________________
dense_1_0_1 (Dense)          (None, 8)                 72        
_________________________________________________________________
output_0_1 (Dense)           (None, 1)                 9         
_________________________________________________________________
subnetwork_output_final (Den (None, 1)                 1         
Total params: 116
Trainable params: 105
Non-trainable params: 11
____________________________________________________________

In [13]:
# The first layer of the subnetwork is its InputLayer.
subnetwork.layers[0]

<tensorflow.python.keras.engine.input_layer.InputLayer at 0x7efd45626090>

In [14]:
# Fixed float32 inputs shared by the checks and explainers below: NUM_SAMPLES rows,
# one column per model input feature (read from the model instead of repeating 5).
# Seeded so that Restart & Run All reproduces the same inputs and outputs.
NUM_SAMPLES = 10
SEED = 0
num_model_features = model.input_shape[-1]
fixed_values = np.random.RandomState(SEED).randn(NUM_SAMPLES, num_model_features).astype(np.float32)

In [15]:
# Forward pass of the extracted subnetwork on the fixed inputs.
subnetwork(fixed_values)

<tf.Tensor: shape=(10, 1), dtype=float32, numpy=
array([[-0.02686454],
       [-0.06490623],
       [ 0.08229259],
       [ 0.01626243],
       [-0.04495191],
       [ 0.0800527 ],
       [ 0.02393124],
       [-0.13535072],
       [-0.0110657 ],
       [ 0.0556884 ]], dtype=float32)>

In [16]:
# Reference model from the full model input to the `output_0_1` subnetwork output,
# used to cross-check the extracted `subnetwork`.
subnetwork_reference_output = model.get_layer('output_0_1').output
test_model = tf.keras.models.Model(inputs=model.input,
                                   outputs=subnetwork_reference_output)

In [17]:
# Reference: the `output_0_1` activation computed from the full model, scaled by its
# `output_final` weight. Asserted equal (up to float error) to the extracted subnetwork
# instead of relying on a visual comparison of two printed tensors.
reference_output = test_model(fixed_values) * final_weighting
np.testing.assert_allclose(reference_output.numpy().ravel(),
                           subnetwork(fixed_values).numpy().ravel(),
                           rtol=1e-5, atol=1e-6)
reference_output

<tf.Tensor: shape=(1, 10, 1), dtype=float32, numpy=
array([[[-0.02686454],
        [-0.06490623],
        [ 0.08229259],
        [ 0.01626243],
        [-0.04495191],
        [ 0.0800527 ],
        [ 0.02393124],
        [-0.13535072],
        [-0.0110657 ],
        [ 0.0556884 ]]], dtype=float32)>

In [18]:
from neural_interaction_detection import NeuralInteractionDetectionExplainerTF

In [19]:
from contextual_decomposition import ContextualDecompositionExplainerTF

In [20]:
# Neural Interaction Detection explainer over the extracted subnetwork.
nid_explainer = NeuralInteractionDetectionExplainerTF(subnetwork)

In [21]:
# Pairwise interaction strengths from NID (displayed next as a 2x2 matrix with zero diagonal).
# NOTE(review): `interactions` is later rebound to the contextual-decomposition results;
# a distinct name such as `nid_interactions` would avoid confusing the two.
interactions = nid_explainer.interactions()

In [22]:
# Display the NID interaction matrix.
interactions

array([[0.        , 0.32805729],
       [0.32805729, 0.        ]])

In [23]:
# Contextual decomposition explainer over the same extracted subnetwork.
cd_explainer = ContextualDecompositionExplainerTF(subnetwork)

In [24]:
# Per-feature contextual-decomposition attributions for the fixed inputs; the second
# returned value is discarded.
# NOTE(review): the meaning of the positional `10` (batch size? sample count?) is not
# visible here — confirm against ContextualDecompositionExplainerTF.attributions.
attributions, _ = cd_explainer.attributions(fixed_values, 10)

In [25]:
# Contextual-decomposition interactions with both features' main effects removed:
# result[b, i, j] = raw[b, i, j] - attributions[b, j] - attributions[b, i].
# NOTE(review): `10` is passed positionally as in the attributions call — confirm its meaning.
raw_cd_interactions, _ = cd_explainer.interactions(fixed_values, 10)
# Broadcast the (assumed (batch, features)) attributions along each matrix axis.
row_main_effects = attributions[:, np.newaxis, :]
col_main_effects = attributions[:, :, np.newaxis]
# NOTE(review): on the diagonal this subtracts attributions[b, i] twice, so entries [i, i]
# are not pairwise interactions (unlike the zero-diagonal NID matrix) — consider masking them.
interactions = raw_cd_interactions - row_main_effects - col_main_effects

In [26]:
# Main-effect-adjusted CD interactions for the first sample. In the output only the
# top-left 2x2 block is nonzero, consistent with `select_0_1` passing 2 features through.
interactions[0]

array([[ 0.03002235, -0.00227935,  0.        ,  0.        ,  0.        ],
       [-0.00227935,  0.01914803,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ]])