/
feature_ablation.py
474 lines (439 loc) · 23.2 KB
/
feature_ablation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
#!/usr/bin/env python3
import torch
from .._utils.common import (
_format_attributions,
_format_input,
_format_input_baseline,
_run_forward,
_expand_additional_forward_args,
_expand_target,
_format_additional_forward_args,
)
from .._utils.attribution import PerturbationAttribution
class FeatureAblation(PerturbationAttribution):
    r"""
    Perturbation-based attribution: each input feature (or each group of
    features sharing a value in a feature mask) is replaced by a baseline, and
    the resulting change in the forward function's output is attributed to
    that feature.
    """

    def __init__(self, forward_func):
        r"""
        Args:

            forward_func (callable): The forward function of the model or
                        any modification of it
        """
        PerturbationAttribution.__init__(self, forward_func)
        # When True, accumulated attributions are divided by the number of
        # ablations that covered each input position (see the end of
        # `attribute`). Intended for child classes whose ablations may overlap;
        # FeatureAblation ablates each position exactly once per input tensor.
        self.use_weights = False

    def attribute(
        self,
        inputs,
        baselines=None,
        target=None,
        additional_forward_args=None,
        feature_mask=None,
        ablations_per_eval=1,
        **kwargs
    ):
        r"""
        A perturbation based approach to computing attribution, involving
        replacing each input feature with a given baseline / reference, and
        computing the difference in output. By default, each scalar value within
        each input tensor is taken as a feature and replaced independently.
        Passing a feature mask allows grouping features to be ablated together.
        This can be used in cases such as images, where an entire segment or
        region can be ablated, measuring the importance of the segment (feature
        group). Each input scalar in the group will be given the same
        attribution value equal to the change in target as a result of ablating
        the entire feature group.

        The forward function can either return a scalar per example, or a single
        scalar for the full batch. If a single scalar is returned for the batch,
        `ablations_per_eval` must be 1, and the returned attributions will have
        first dimension 1, corresponding to feature importance across all
        examples in the batch.

        Args:

            inputs (tensor or tuple of tensors): Input for which ablation
                        attributions are computed. If forward_func takes a single
                        tensor as input, a single input tensor should be provided.
                        If forward_func takes multiple tensors as input, a tuple
                        of the input tensors should be provided. It is assumed
                        that for all given input tensors, dimension 0 corresponds
                        to the number of examples (aka batch size), and if
                        multiple input tensors are provided, the examples must
                        be aligned appropriately.
            baselines (scalar, tensor, tuple of scalars or tensors, optional):
                        Baselines define reference value which replaces each
                        feature when ablated.
                        Baselines can be provided as:

                        - a single tensor, if inputs is a single tensor, with
                          exactly the same dimensions as inputs or
                          broadcastable to match the dimensions of inputs

                        - a single scalar, if inputs is a single tensor, which will
                          be broadcasted for each input value in input tensor.

                        - a tuple of tensors or scalars, the baseline corresponding
                          to each tensor in the inputs' tuple can be:

                          - either a tensor with
                            exactly the same dimensions as inputs or
                            broadcastable to match the dimensions of inputs
                          - or a scalar, corresponding to a tensor in the
                            inputs' tuple. This scalar value is broadcasted
                            for corresponding input tensor.

                        In the cases when `baselines` is not provided, we internally
                        use zero scalar corresponding to each input tensor.
                        Default: None
            target (int, tuple, tensor or list, optional): Output indices for
                        which difference is computed (for classification cases,
                        this is usually the target class).
                        If the network returns a scalar value per example,
                        no target index is necessary.
                        For general 2D outputs, targets can be either:

                        - a single integer or a tensor containing a single
                          integer, which is applied to all input examples

                        - a list of integers or a 1D tensor, with length matching
                          the number of examples in inputs (dim 0). Each integer
                          is applied as the target for the corresponding example.

                        For outputs with > 2 dimensions, targets can be either:

                        - A single tuple, which contains #output_dims - 1
                          elements. This target index is applied to all examples.

                        - A list of tuples with length equal to the number of
                          examples in inputs (dim 0), and each tuple containing
                          #output_dims - 1 elements. Each tuple is applied as the
                          target for the corresponding example.

                        Default: None
            additional_forward_args (tuple, optional): If the forward function
                        requires additional arguments other than the inputs for
                        which attributions should not be computed, this argument
                        can be provided. It must be either a single additional
                        argument of a Tensor or arbitrary (non-tuple) type or a
                        tuple containing multiple additional arguments including
                        tensors or any arbitrary python types. These arguments
                        are provided to forward_func in order following the
                        arguments in inputs.
                        For a tensor, the first dimension of the tensor must
                        correspond to the number of examples. For all other types,
                        the given argument is used for all forward evaluations.
                        Note that attributions are not computed with respect
                        to these arguments.
                        Default: None
            feature_mask (tensor or tuple of tensors, optional):
                        feature_mask defines a mask for the input, grouping
                        features which should be ablated together. feature_mask
                        should contain the same number of tensors as inputs.
                        Each tensor should be the same size as the corresponding
                        input or broadcastable to match the input tensor. Each
                        tensor should contain integers in the range 0 to
                        num_features - 1, and indices corresponding to the same
                        feature should have the same value.
                        Note that features within each input tensor are ablated
                        independently (not across tensors).
                        If the forward function returns a single scalar per batch,
                        we enforce that the first dimension of each mask must be 1,
                        since attributions are returned batch-wise rather than per
                        example, so the attributions must correspond to the
                        same features (indices) in each input example.
                        If None, then a feature mask is constructed which assigns
                        each scalar within a tensor as a separate feature, which
                        is ablated independently.
                        Default: None
            ablations_per_eval (int, optional): Allows ablation of multiple features
                        to be processed simultaneously in one call to forward_fn.
                        Each forward pass will contain a maximum of
                        ablations_per_eval * #examples samples.
                        For DataParallel models, each batch is split among the
                        available devices, so evaluations on each available
                        device contain at most
                        (ablations_per_eval * #examples) / num_devices
                        samples.
                        If the forward function returns a single scalar per batch,
                        ablations_per_eval must be set to 1.
                        Default: 1
            **kwargs (Any, optional): Any additional arguments used by child
                        classes of FeatureAblation (such as Occlusion) to construct
                        ablations. These arguments are ignored when using
                        FeatureAblation directly.
                        Default: None

        Returns:
            *tensor* or tuple of *tensors* of **attributions**:
            - **attributions** (*tensor* or tuple of *tensors*):
                        The attributions with respect to each input feature.
                        If the forward function returns
                        a scalar value per example, attributions will be
                        the same size as the provided inputs, with each value
                        providing the attribution of the corresponding input index.
                        If the forward function returns a scalar per batch, then
                        attribution tensor(s) will have first dimension 1 and
                        the remaining dimensions will match the input.
                        If a single tensor is provided as inputs, a single tensor is
                        returned. If a tuple is provided for inputs, a tuple of
                        corresponding sized tensors is returned.

        Examples::

            >>> # SimpleClassifier takes a single input tensor of size Nx4x4,
            >>> # and returns an Nx3 tensor of class probabilities.
            >>> net = SimpleClassifier()
            >>> # Generating random input with size 2 x 4 x 4
            >>> input = torch.randn(2, 4, 4)
            >>> # Defining FeatureAblation interpreter
            >>> ablator = FeatureAblation(net)
            >>> # Computes ablation attribution, ablating each of the 16
            >>> # scalar input independently.
            >>> attr = ablator.attribute(input, target=1)

            >>> # Alternatively, we may want to ablate features in groups, e.g.
            >>> # grouping each 2x2 square of the inputs and ablating them together.
            >>> # This can be done by creating a feature mask as follows, which
            >>> # defines the feature groups, e.g.:
            >>> # +---+---+---+---+
            >>> # | 0 | 0 | 1 | 1 |
            >>> # +---+---+---+---+
            >>> # | 0 | 0 | 1 | 1 |
            >>> # +---+---+---+---+
            >>> # | 2 | 2 | 3 | 3 |
            >>> # +---+---+---+---+
            >>> # | 2 | 2 | 3 | 3 |
            >>> # +---+---+---+---+
            >>> # With this mask, all inputs with the same value are ablated
            >>> # simultaneously, and the attribution for each input in the same
            >>> # group (0, 1, 2, and 3) per example are the same.
            >>> # The attributions can be calculated as follows:
            >>> # feature mask has dimensions 1 x 4 x 4
            >>> feature_mask = torch.tensor([[[0,0,1,1],[0,0,1,1],
            >>>                             [2,2,3,3],[2,2,3,3]]])
            >>> attr = ablator.attribute(input, target=1, feature_mask=feature_mask)
        """
        with torch.no_grad():
            # Keeps track whether original input is a tuple or not before
            # converting it into a tuple.
            is_inputs_tuple = isinstance(inputs, tuple)
            inputs, baselines = _format_input_baseline(inputs, baselines)
            additional_forward_args = _format_additional_forward_args(
                additional_forward_args
            )
            num_examples = inputs[0].shape[0]
            feature_mask = (
                _format_input(feature_mask) if feature_mask is not None else None
            )
            assert (
                isinstance(ablations_per_eval, int) and ablations_per_eval >= 1
            ), "Ablations per evaluation must be at least 1."

            # Computes initial evaluation with all features, which is compared
            # to each ablated result.
            initial_eval = _run_forward(
                self.forward_func, inputs, target, additional_forward_args
            )

            # The forward output is treated as one scalar for the whole batch
            # when it is a Python number, a 0-d tensor, or a one-element tensor
            # produced from a batch of more than one example.
            if isinstance(initial_eval, (int, float)) or (
                isinstance(initial_eval, torch.Tensor)
                and (
                    len(initial_eval.shape) == 0
                    or (num_examples > 1 and initial_eval.numel() == 1)
                )
            ):
                single_output_mode = True
                assert (
                    ablations_per_eval == 1
                ), "Cannot have ablations_per_eval > 1 when function returns scalar."
                if feature_mask is not None:
                    for single_mask in feature_mask:
                        assert single_mask.shape[0] == 1, (
                            "Cannot provide multiple masks when function returns"
                            " a scalar."
                        )
            else:
                single_output_mode = False
                assert (
                    isinstance(initial_eval, torch.Tensor)
                    and initial_eval[0].numel() == 1
                ), "Target should identify a single element in the model output."
                # Row vector so it broadcasts against (#ablations, #examples).
                initial_eval = initial_eval.reshape(1, num_examples)

            # Initialize attribution totals and counts
            attrib_type = (
                initial_eval.dtype
                if isinstance(initial_eval, torch.Tensor)
                else type(initial_eval)
            )
            total_attrib = [
                torch.zeros_like(
                    inp[0:1] if single_output_mode else inp, dtype=attrib_type
                )
                for inp in inputs
            ]

            # Weights are used in cases where ablations may be overlapping.
            if self.use_weights:
                weights = [
                    torch.zeros_like(inp[0:1] if single_output_mode else inp).float()
                    for inp in inputs
                ]

            # Iterate through each feature tensor for ablation
            for i, input_tensor in enumerate(inputs):
                for (
                    current_inputs,
                    current_add_args,
                    current_target,
                    current_mask,
                ) in self._ablation_generator(
                    i,
                    inputs,
                    additional_forward_args,
                    target,
                    baselines,
                    feature_mask,
                    ablations_per_eval,
                    **kwargs
                ):
                    # modified_eval dimensions: 1D tensor with length
                    # equal to #num_examples * #features in batch
                    modified_eval = _run_forward(
                        self.forward_func,
                        current_inputs,
                        current_target,
                        current_add_args,
                    )
                    # eval_diff dimensions: (#features in batch, #num_examples, 1,.. 1)
                    # (contains 1 more dimension than inputs). This adds extra
                    # dimensions of 1 to make the tensor broadcastable with the inputs
                    # tensor.
                    if single_output_mode:
                        eval_diff = initial_eval - modified_eval
                    else:
                        eval_diff = (
                            initial_eval - modified_eval.reshape(-1, num_examples)
                        ).reshape(
                            (-1, num_examples) + (len(input_tensor.shape) - 1) * (1,)
                        )
                    if self.use_weights:
                        weights[i] += current_mask.float().sum(dim=0)
                    # Only positions marked 1 in the mask receive this
                    # ablation's output difference.
                    total_attrib[i] += (eval_diff * current_mask.to(attrib_type)).sum(
                        dim=0
                    )

            # Divide total attributions by counts and return formatted attributions
            if self.use_weights:
                attrib = tuple(
                    single_attrib.float() / weight
                    for single_attrib, weight in zip(total_attrib, weights)
                )
            else:
                attrib = tuple(total_attrib)
            return _format_attributions(is_inputs_tuple, attrib)

    def _ablation_generator(
        self,
        i,
        inputs,
        additional_args,
        target,
        baselines,
        input_mask,
        ablations_per_eval,
        **kwargs
    ):
        r"""
        Yields batches that ablate the features of ``inputs[i]`` in chunks of at
        most ``ablations_per_eval`` features.

        Each yielded item is ``(inputs, additional_args, target, mask)``:
        every tensor in ``inputs`` is tiled along dim 0 once per feature in the
        chunk, and the copies of ``inputs[i]`` have the corresponding feature
        replaced by its baseline. ``additional_args`` and ``target`` are
        expanded to match. ``mask`` is the tensor returned by
        ``_construct_ablated_input`` (1 where a position was ablated).

        Tuple-valued entries of ``kwargs`` are indexed with ``i`` before being
        forwarded to the mask/ablation helpers; other values are passed as is.
        """
        extra_args = {}
        for key, value in kwargs.items():
            # For any tuple argument in kwargs, we choose index i of the tuple.
            if isinstance(value, tuple):
                extra_args[key] = value[i]
            else:
                extra_args[key] = value

        input_mask = input_mask[i] if input_mask is not None else None
        min_feature, num_features, input_mask = self._get_feature_range_and_mask(
            inputs[i], input_mask, **extra_args
        )
        num_examples = inputs[0].shape[0]
        # Feature indices iterated are [min_feature, num_features), so never
        # tile the batch more times than there are features to ablate.
        ablations_per_eval = min(ablations_per_eval, num_features - min_feature)
        baseline = baselines[i] if isinstance(baselines, tuple) else baselines
        if isinstance(baseline, torch.Tensor):
            # Leading singleton dim broadcasts over the ablation dimension.
            baseline = baseline.reshape((1,) + baseline.shape)

        # Repeat features and additional args for batch size.
        all_features_repeated = [
            torch.cat([inputs[j]] * ablations_per_eval, dim=0)
            for j in range(len(inputs))
        ]
        additional_args_repeated = (
            _expand_additional_forward_args(additional_args, ablations_per_eval)
            if additional_args is not None
            else None
        )
        target_repeated = _expand_target(target, ablations_per_eval)

        num_features_processed = min_feature
        while num_features_processed < num_features:
            current_num_ablated_features = min(
                ablations_per_eval, num_features - num_features_processed
            )

            # Store appropriate inputs and additional args based on batch size.
            if current_num_ablated_features != ablations_per_eval:
                # Final, smaller chunk: trim the pre-tiled tensors instead of
                # re-tiling them.
                current_features = [
                    feature_repeated[0 : current_num_ablated_features * num_examples]
                    for feature_repeated in all_features_repeated
                ]
                current_additional_args = (
                    _expand_additional_forward_args(
                        additional_args, current_num_ablated_features
                    )
                    if additional_args is not None
                    else None
                )
                current_target = _expand_target(target, current_num_ablated_features)
            else:
                current_features = all_features_repeated
                current_additional_args = additional_args_repeated
                current_target = target_repeated

            # Store existing tensor before modifying
            original_tensor = current_features[i]
            # Construct ablated batch for features in range num_features_processed
            # to num_features_processed + current_num_ablated_features and return
            # mask with same size as ablated batch. ablated_features has dimension
            # (current_num_ablated_features, num_examples, + inputs[i].shape[1:])
            ablated_features, current_mask = self._construct_ablated_input(
                current_features[i].reshape(
                    (current_num_ablated_features, num_examples)
                    + current_features[i].shape[1:]
                ),
                input_mask,
                baseline,
                num_features_processed,
                num_features_processed + current_num_ablated_features,
                **extra_args
            )

            # current_features[i] has dimension
            # (current_num_ablated_features * num_examples, inputs[i].shape[1:]),
            # which can be provided to the model as input.
            current_features[i] = ablated_features.reshape(
                (-1,) + ablated_features.shape[2:]
            )
            yield tuple(
                current_features
            ), current_additional_args, current_target, current_mask
            # Replace existing tensor at index i. This is required because
            # current_features may be the shared all_features_repeated list.
            current_features[i] = original_tensor

            num_features_processed += current_num_ablated_features

    def _construct_ablated_input(
        self, expanded_input, input_mask, baseline, start_feature, end_feature, **kwargs
    ):
        r"""
        Ablates given expanded_input tensor with given feature mask, feature range,
        and baselines. expanded_input shape is (`num_features`, `num_examples`, ...)
        with remaining dimensions corresponding to remaining original tensor
        dimensions and `num_features` = `end_feature` - `start_feature`.
        input_mask has same number of dimensions as original input tensor (one less
        than `expanded_input`), and can have first dimension either 1, applying same
        feature mask to all examples, or `num_examples`. baseline is expected to
        be broadcastable to match `expanded_input`.

        This method returns the ablated input tensor, which has the same
        dimensionality as `expanded_input` as well as the corresponding mask with
        either the same dimensionality as `expanded_input` or second dimension
        being 1. This mask contains 1s in locations which have been ablated (and
        thus counted towards ablations for that feature) and 0s otherwise.
        """
        # One mask slice per feature index in [start_feature, end_feature).
        current_mask = torch.stack(
            [input_mask == j for j in range(start_feature, end_feature)], dim=0
        ).long()
        # NOTE(review): this arithmetic blend propagates NaN/inf from
        # `expanded_input` into ablated positions (inf * 0 == nan); a
        # torch.where-based select would avoid that, but its dtype-promotion
        # behavior differs on older torch releases — confirm before changing.
        ablated_tensor = (
            expanded_input * (1 - current_mask).to(expanded_input.dtype)
        ) + (baseline * current_mask.to(expanded_input.dtype))
        return ablated_tensor, current_mask

    def _get_feature_range_and_mask(self, input, input_mask, **kwargs):
        r"""
        Returns ``(min_feature, max_feature + 1, input_mask)`` for one input
        tensor. When ``input_mask`` is None, a mask of shape
        ``(1,) + input.shape[1:]`` is built that gives every scalar of one
        example its own feature index (0 .. numel - 1).
        """
        if input_mask is None:
            # Obtain feature mask for selected input tensor, matches size of
            # 1 input example, (1 x inputs[i].shape[1:])
            input_mask = torch.reshape(
                torch.arange(torch.numel(input[0]), device=input.device),
                input[0:1].shape,
            ).long()
        return (
            torch.min(input_mask).item(),
            torch.max(input_mask).item() + 1,
            input_mask,
        )