In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('..')

In [3]:
import tensorflow as tf
import numpy as np
from build_model import interaction_model

In [4]:
from path_explain import utils
utils.set_up_environment(visible_devices='1')

In [5]:
model = interaction_model(num_features=5,
                          num_layers=2,
                          hidden_layer_size=8,
                          num_outputs=1,
                          activation_function=tf.keras.activations.relu,
                          interactions_to_ignore=None)

In [6]:
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input (InputLayer)              [(None, 5)]          0                                            
__________________________________________________________________________________________________
select_0 (Dense)                (None, 1)            5           input[0][0]                      
__________________________________________________________________________________________________
select_1 (Dense)                (None, 1)            5           input[0][0]                      
__________________________________________________________________________________________________
select_2 (Dense)                (None, 1)            5           input[0][0]                      
______________________________________________________________________________________________

In [7]:
model.layers[20].get_config().get('activation', None)

'relu'

In [8]:
for layer in model.layers:
    if layer.get_config().get('activation', None) == 'relu':
        print(layer.name)

dense_0_0
dense_0_1
dense_0_2
dense_0_3
dense_0_4
dense_0_0_1
dense_0_0_2
dense_0_0_3
dense_0_0_4
dense_0_1_2
dense_0_1_3
dense_0_1_4
dense_0_2_3
dense_0_2_4
dense_0_3_4
dense_1_0
dense_1_1
dense_1_2
dense_1_3
dense_1_4
dense_1_0_1
dense_1_0_2
dense_1_0_3
dense_1_0_4
dense_1_1_2
dense_1_1_3
dense_1_1_4
dense_1_2_3
dense_1_2_4
dense_1_3_4


In [9]:
model.layers[0].trainable

True

In [10]:
input_tensor = model.get_layer('select_0_1').input
subnetwork = tf.keras.models.Model(inputs=input_tensor,
                                   outputs=model.get_layer('output_0_1').output)
output_tensor = input_tensor
for layer in subnetwork.layers[1:]:
    output_tensor = layer(output_tensor)

weight_multiply_index = [layer.name.split('/')[0] for layer in model.get_layer('concat').input].index('output_0_1')
final_weighting = model.get_layer('output_final').weights[0][weight_multiply_index, :]
final_weighting = tf.expand_dims(final_weighting, axis=0)
final_weighting = tf.expand_dims(final_weighting, axis=0)

output_tensor = tf.keras.layers.Dense(units=1,
                                      activation=None,
                                      use_bias=False,
                                      weights=final_weighting.numpy(),
                                      trainable=False,
                                      name='subnetwork_output_final')(output_tensor)
subnetwork = tf.keras.models.Model(inputs=input_tensor,
                                   outputs=output_tensor)

In [11]:
model.get_layer('select_0_1').input

<tf.Tensor 'input:0' shape=(None, 5) dtype=float32>

In [12]:
tf.keras.layers.Input

<function tensorflow.python.keras.engine.input_layer.Input(shape=None, batch_size=None, name=None, dtype=None, sparse=False, tensor=None, ragged=False, **kwargs)>

In [13]:
subnetwork.summary()

Model: "model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input (InputLayer)           [(None, 5)]               0         
_________________________________________________________________
select_0_1 (Dense)           (None, 2)                 10        
_________________________________________________________________
dense_0_0_1 (Dense)          (None, 8)                 24        
_________________________________________________________________
dense_1_0_1 (Dense)          (None, 8)                 72        
_________________________________________________________________
output_0_1 (Dense)           (None, 1)                 9         
_________________________________________________________________
subnetwork_output_final (Den (None, 1)                 1         
Total params: 116
Trainable params: 105
Non-trainable params: 11
____________________________________________________________

In [14]:
subnetwork.layers[0]

<tensorflow.python.keras.engine.input_layer.InputLayer at 0x7f9109cebbd0>

In [15]:
fixed_values = np.random.randn(10, 5).astype(np.float32)

In [16]:
subnetwork(fixed_values)

<tf.Tensor: shape=(10, 1), dtype=float32, numpy=
array([[0.        ],
       [0.07608242],
       [0.7607951 ],
       [0.16543923],
       [0.        ],
       [0.4151308 ],
       [0.        ],
       [0.27902663],
       [0.66558653],
       [0.23474546]], dtype=float32)>

In [17]:
test_model = tf.keras.models.Model(inputs=model.input, outputs=model.get_layer('output_0_1').output)

In [18]:
test_model(fixed_values) * final_weighting

<tf.Tensor: shape=(1, 10, 1), dtype=float32, numpy=
array([[[0.        ],
        [0.07608242],
        [0.7607951 ],
        [0.16543923],
        [0.        ],
        [0.4151308 ],
        [0.        ],
        [0.27902663],
        [0.66558653],
        [0.23474546]]], dtype=float32)>

In [19]:
from neural_interaction_detection import NeuralInteractionDetectionExplainerTF

In [20]:
from contextual_decomposition import ContextualDecompositionExplainerTF

In [21]:
nid_explainer = NeuralInteractionDetectionExplainerTF(subnetwork)

In [22]:
interactions = nid_explainer.interactions()

In [23]:
interactions

array([[0.        , 1.74311638],
       [1.74311638, 0.        ]])

In [24]:
cd_explainer = ContextualDecompositionExplainerTF(subnetwork)

In [25]:
attributions, _ = cd_explainer.attributions(fixed_values, 10)

In [26]:
interactions, _ = cd_explainer.interactions(fixed_values, 10)
interactions = interactions - attributions[:, np.newaxis, :] - attributions[:, :, np.newaxis]

In [27]:
interactions[0]

array([[-0.04032952, -0.03299661,  0.        ,  0.        ,  0.        ],
       [-0.03299661, -0.02566369,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ]])