ablation_cam.py
import numpy as np
import torch
import tqdm
from typing import Callable, List
from pytorch_grad_cam.base_cam import BaseCAM
from pytorch_grad_cam.utils.find_layers import replace_layer_recursive
from pytorch_grad_cam.ablation_layer import AblationLayer

""" Implementation of AblationCAM
https://openaccess.thecvf.com/content_WACV_2020/papers/Desai_Ablation-CAM_Visual_Explanations_for_Deep_Convolutional_Network_via_Gradient-free_Localization_WACV_2020_paper.pdf

Ablate individual activations and then measure the drop in the target scores.

In the current implementation, the target layer activations are cached, so they won't be re-computed.
However, layers before it, if any, will not be cached.
This means that if the target layer is a large block, for example model.features (in VGG), there will
be a large saving in run time.

Since we have to go over many channels and ablate them, and every channel ablation requires a forward pass,
it would be nice if we could avoid doing that for channels that won't contribute anyway, making it much faster.
The parameter ratio_channels_to_ablate controls how many channels should be ablated, using an experimental method
(to be improved). The default value of 1.0 means that all channels will be ablated.
"""

class AblationCAM(BaseCAM):
    def __init__(self,
                 model: torch.nn.Module,
                 target_layers: List[torch.nn.Module],
                 reshape_transform: Callable = None,
                 ablation_layer: torch.nn.Module = AblationLayer(),
                 batch_size: int = 32,
                 ratio_channels_to_ablate: float = 1.0) -> None:
        super(AblationCAM, self).__init__(model,
                                          target_layers,
                                          reshape_transform,
                                          uses_gradients=False)
        self.batch_size = batch_size
        self.ablation_layer = ablation_layer
        self.ratio_channels_to_ablate = ratio_channels_to_ablate

    def save_activation(self, module, input, output) -> None:
        """ Helper function to save the raw activations from the target layer """
        self.activations = output

    def assemble_ablation_scores(self,
                                 new_scores: list,
                                 original_score: float,
                                 ablated_channels: np.ndarray,
                                 number_of_channels: int) -> np.ndarray:
        """ Take the value from the channels that were ablated,
            and just set the original score for the channels that were skipped """
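        # Illustrative example with made-up numbers (not from the library's tests):
        # number_of_channels=4, ablated_channels=[1, 3], new_scores=[0.2, 0.5],
        # original_score=0.9 -> channels 0 and 2 were skipped and keep the original
        # score, so the method returns [0.9, 0.2, 0.9, 0.5].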
        index = 0
        result = []
        sorted_indices = np.argsort(ablated_channels)
        ablated_channels = ablated_channels[sorted_indices]
        new_scores = np.float32(new_scores)[sorted_indices]

        for i in range(number_of_channels):
            if index < len(ablated_channels) and ablated_channels[index] == i:
                weight = new_scores[index]
                index = index + 1
            else:
                weight = original_score
            result.append(weight)

        return result

    def get_cam_weights(self,
                        input_tensor: torch.Tensor,
                        target_layer: torch.nn.Module,
                        targets: List[Callable],
                        activations: torch.Tensor,
                        grads: torch.Tensor) -> np.ndarray:
        # Do a forward pass, compute the target scores, and cache the
        # activations
        handle = target_layer.register_forward_hook(self.save_activation)
        with torch.no_grad():
            outputs = self.model(input_tensor)
            handle.remove()
            original_scores = np.float32(
                [target(output).cpu().item() for target, output in zip(targets, outputs)])

        # Replace the layer with the ablation layer.
        # When we finish, we will replace it back, so the
        # original model is unchanged.
        ablation_layer = self.ablation_layer
        replace_layer_recursive(self.model, target_layer, ablation_layer)

        number_of_channels = activations.shape[1]
        weights = []
        # This is a "gradient free" method, so we don't need gradients here.
        with torch.no_grad():
            # Loop over each of the batch images and ablate activations for it.
            for batch_index, (target, tensor) in enumerate(
                    zip(targets, input_tensor)):
                new_scores = []
                batch_tensor = tensor.repeat(self.batch_size, 1, 1, 1)

                # Check which channels should be ablated. Normally this will be all channels,
                # But we can also try to speed this up by using a low
                # ratio_channels_to_ablate.
                channels_to_ablate = ablation_layer.activations_to_be_ablated(
                    activations[batch_index, :], self.ratio_channels_to_ablate)
                number_channels_to_ablate = len(channels_to_ablate)

                for i in tqdm.tqdm(
                        range(
                            0,
                            number_channels_to_ablate,
                            self.batch_size)):
                    if i + self.batch_size > number_channels_to_ablate:
                        batch_tensor = batch_tensor[:(number_channels_to_ablate - i)]

                    # Change the state of the ablation layer so it ablates the next channels.
                    # TBD: Move this into the ablation layer forward pass.
                    ablation_layer.set_next_batch(
                        input_batch_index=batch_index,
                        activations=self.activations,
                        num_channels_to_ablate=batch_tensor.size(0))
                    score = [target(o).cpu().item()
                             for o in self.model(batch_tensor)]
                    new_scores.extend(score)
                    ablation_layer.indices = ablation_layer.indices[batch_tensor.size(0):]

                new_scores = self.assemble_ablation_scores(
                    new_scores,
                    original_scores[batch_index],
                    channels_to_ablate,
                    number_of_channels)
                weights.extend(new_scores)
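
        # Normalize into the Ablation-CAM weights computed below: for each channel k,
        #     weight_k = (original_score - ablated_score_k) / original_score
        # Illustrative numbers (not from the paper): a score drop from 0.9 to 0.3 gives a
        # weight of (0.9 - 0.3) / 0.9 ≈ 0.67, while a channel whose ablation leaves the
        # score unchanged gets weight 0.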
        weights = np.float32(weights)
        weights = weights.reshape(activations.shape[:2])
        original_scores = original_scores[:, None]
        weights = (original_scores - weights) / original_scores

        # Replace the model back to the original state
        replace_layer_recursive(self.model, ablation_layer, target_layer)
        # Returning the weights from new_scores
        return weights
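

# Hypothetical usage sketch, not part of the library file above: it assumes BaseCAM exposes
# the usual __call__(input_tensor=..., targets=...) interface from this package and that
# torchvision is installed. The model, the layer4[-1] target layer, class index 281 and the
# random input tensor are illustrative stand-ins only; any callable mapping one model output
# row to a scalar score can serve as a target.
if __name__ == "__main__":
    import torchvision

    model = torchvision.models.resnet50().eval()
    input_tensor = torch.randn(1, 3, 224, 224)       # stand-in for a real preprocessed image
    targets = [lambda output: output[281]]           # scalar score for class 281, one target per image

    cam = AblationCAM(model=model,
                      target_layers=[model.layer4[-1]],
                      batch_size=32,
                      ratio_channels_to_ablate=1.0)  # ablate every channel (the default)
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
    print(grayscale_cam.shape)                       # expected: a (batch, H, W) heatmap per input image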