In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('..')

In [3]:
import tensorflow as tf
import numpy as np
from build_model import interaction_model

In [4]:
from path_explain import utils
utils.set_up_environment(visible_devices='1')

In [5]:
model = interaction_model(num_features=5,
                          num_layers=2,
                          hidden_layer_size=8,
                          num_outputs=1,
                          activation_function=tf.keras.activations.relu,
                          interactions_to_ignore=None)

In [6]:
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input (InputLayer)              [(None, 5)]          0                                            
__________________________________________________________________________________________________
select_0 (Dense)                (None, 1)            5           input[0][0]                      
__________________________________________________________________________________________________
select_1 (Dense)                (None, 1)            5           input[0][0]                      
__________________________________________________________________________________________________
select_2 (Dense)                (None, 1)            5           input[0][0]                      
______________________________________________________________________________________________

In [7]:
model.layers[20].get_config().get('activation', None)

'relu'

In [8]:
for layer in model.layers:
    if layer.get_config().get('activation', None) == 'relu':
        print(layer.name)

dense_0_0
dense_0_1
dense_0_2
dense_0_3
dense_0_4
dense_0_0_1
dense_0_0_2
dense_0_0_3
dense_0_0_4
dense_0_1_2
dense_0_1_3
dense_0_1_4
dense_0_2_3
dense_0_2_4
dense_0_3_4
dense_1_0
dense_1_1
dense_1_2
dense_1_3
dense_1_4
dense_1_0_1
dense_1_0_2
dense_1_0_3
dense_1_0_4
dense_1_1_2
dense_1_1_3
dense_1_1_4
dense_1_2_3
dense_1_2_4
dense_1_3_4


In [54]:
model.get_layer('output_0_1')._outbound_nodes[0].outbound_layer.name

'concat'

In [37]:
model.get_layer('dense_0_0_1')._inbound_nodes[0].inbound_layers.name

AttributeError: 'Dense' object has no attribute '_output_nodes'

In [21]:
model.get_layer('output_0_1').output

<tf.Tensor 'output_0_1/Identity:0' shape=(None, 1) dtype=float32>

In [43]:
model.get_layer('dense_1_0_1').input

<tf.Tensor 'dense_0_0_1/Identity:0' shape=(None, 8) dtype=float32>

In [50]:
input_tensor = tf.keras.layers.Input(shape=2)

In [64]:
input_tensor  = tf.keras.layers.Input(shape=2)
output_tensor = input_tensor
layer = model.get_layer('dense_0_0_1')
while 'concat' not in layer.name:
    output_tensor = layer(output_tensor)
    next_layer_name = layer._outbound_nodes[0].outbound_layer.name
    layer = model.get_layer(next_layer_name)

weight_multiply_index = [layer.name.split('/')[0] for layer in model.get_layer('concat').input].index('output_0_1')
final_weighting = model.get_layer('output_final').weights[0][weight_multiply_index, :]
final_weighting = tf.expand_dims(final_weighting, axis=0)
final_weighting = tf.expand_dims(final_weighting, axis=0)

output_tensor = tf.keras.layers.Dense(units=1,
                                      activation=None,
                                      use_bias=False,
                                      weights=final_weighting.numpy(),
                                      trainable=False,
                                      name='subnetwork_output_final')(output_tensor)
subnetwork = tf.keras.models.Model(inputs=input_tensor,
                                   outputs=output_tensor)

In [66]:
subnetwork.summary()

Model: "model_11"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_6 (InputLayer)         [(None, 2)]               0         
_________________________________________________________________
dense_0_0_1 (Dense)          (None, 8)                 24        
_________________________________________________________________
dense_1_0_1 (Dense)          (None, 8)                 72        
_________________________________________________________________
output_0_1 (Dense)           (None, 1)                 9         
_________________________________________________________________
subnetwork_output_final (Den (None, 1)                 1         
Total params: 106
Trainable params: 105
Non-trainable params: 1
_________________________________________________________________


In [67]:
fixed_values = np.random.randn(10, 5).astype(np.float32)

In [69]:
subnetwork(fixed_values[:, 0:2])

<tf.Tensor: shape=(10, 1), dtype=float32, numpy=
array([[-0.01589721],
       [-0.00501881],
       [-0.01682438],
       [-0.00513313],
       [-0.00910988],
       [-0.03814395],
       [-0.01566765],
       [-0.00903523],
       [-0.04518983],
       [-0.06093765]], dtype=float32)>

In [70]:
test_model = tf.keras.models.Model(inputs=model.input, outputs=model.get_layer('output_0_1').output)

In [71]:
test_model(fixed_values) * final_weighting

<tf.Tensor: shape=(1, 10, 1), dtype=float32, numpy=
array([[[-0.01589721],
        [-0.00501881],
        [-0.01682438],
        [-0.00513313],
        [-0.00910988],
        [-0.03814395],
        [-0.01566765],
        [-0.00903523],
        [-0.04518983],
        [-0.06093765]]], dtype=float32)>

In [73]:
from neural_interaction_detection import NeuralInteractionDetectionExplainerTF

In [72]:
from contextual_decomposition import ContextualDecompositionExplainerTF

In [74]:
nid_explainer = NeuralInteractionDetectionExplainerTF(subnetwork)

In [75]:
interactions = nid_explainer.interactions()

In [76]:
interactions

array([[0.        , 1.76005387],
       [1.76005387, 0.        ]])

In [77]:
cd_explainer = ContextualDecompositionExplainerTF(subnetwork)

In [78]:
attributions, _ = cd_explainer.attributions(fixed_values[:, 0:2], 10)

In [79]:
interactions, _ = cd_explainer.interactions(fixed_values[:, 0:2], 10)
interactions = interactions - attributions[:, np.newaxis, :] - attributions[:, :, np.newaxis]

In [80]:
interactions[0]

array([[0.00403177, 0.00012352],
       [0.00012352, 0.02800969]])