import warnings
from dataclasses import dataclass
from enum import Enum
from typing import Union, List, Optional, Tuple
import torch
from PIL.Image import Image
from diffusers_interpret.generated_images import GeneratedImages
from diffusers_interpret.pixel_attributions import PixelAttributions
from diffusers_interpret.saliency_map import SaliencyMap
from diffusers_interpret.token_attributions import TokenAttributions


@dataclass
class BaseMimicPipelineCallOutput:
    """
    Output class for BasePipelineExplainer._mimic_pipeline_call

    Args:
        images (`List[Image]` or `torch.Tensor`)
            List of denoised PIL images of length `batch_size` or a numpy array of shape
            `(batch_size, height, width, num_channels)` containing the denoised images of the diffusion pipeline.
        nsfw_content_detected (`Optional[List[bool]]`)
            List of flags denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content.
        all_images_during_generation (`Optional[Union[List[List[Image]], List[torch.Tensor]]]`)
            A list with all the batch images generated during diffusion.
    """
    images: Union[List[Image], torch.Tensor]
    nsfw_content_detected: Optional[List[bool]] = None
    all_images_during_generation: Optional[Union[List[List[Image]], List[torch.Tensor]]] = None

    def __getitem__(self, item):
        return getattr(self, item)

    def __setitem__(self, key, value):
        setattr(self, key, value)
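
# The `__getitem__`/`__setitem__` pair above makes the output dataclasses usable with
# dict-style access as well as attribute access. A minimal sketch (`images` is a
# hypothetical list of PIL images, shown for illustration only):
#
#     out = BaseMimicPipelineCallOutput(images=images)
#     out["nsfw_content_detected"] = [False]   # same as out.nsfw_content_detected = [False]
#     first_image = out["images"][0]           # same as out.images[0]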


@dataclass
class PipelineExplainerOutput:
    """
    Output class for BasePipelineExplainer.__call__ if `init_image=None` and `explanation_2d_bounding_box=None`

    Args:
        image (`Image` or `torch.Tensor`)
            The denoised PIL output image or torch.Tensor of shape `(height, width, num_channels)`.
        nsfw_content_detected (`Optional[bool]`)
            A flag denoting whether the generated image likely represents "not-safe-for-work"
            (nsfw) content.
        all_images_during_generation (`Optional[Union[GeneratedImages, List[torch.Tensor]]]`)
            A GeneratedImages object to visualize all the images generated during diffusion,
            or a list of tensors of those images.
        token_attributions (`Optional[TokenAttributions]`)
            A TokenAttributions object containing a list of `(token, token_attribution)` tuples.
    """
    image: Union[Image, torch.Tensor]
    nsfw_content_detected: Optional[bool] = None
    all_images_during_generation: Optional[Union[GeneratedImages, List[torch.Tensor]]] = None
    token_attributions: Optional[TokenAttributions] = None

    def __getitem__(self, item):
        return getattr(self, item)

    def __setitem__(self, key, value):
        setattr(self, key, value)

    def __getattr__(self, attr):
        if attr == 'normalized_token_attributions':
            warnings.warn(
                f"`normalized_token_attributions` is deprecated as an attribute of `{self.__class__.__name__}` "
                f"and will be removed in a future version. Consider using `output.token_attributions.normalized` instead",
                DeprecationWarning, stacklevel=2
            )
            return self.token_attributions.normalized
        raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{attr}'")
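
# `__getattr__` keeps the old `normalized_token_attributions` attribute working while
# steering callers to the new location. A minimal sketch (`img` and `token_attributions`
# are hypothetical values, shown for illustration only):
#
#     out = PipelineExplainerOutput(image=img, token_attributions=token_attributions)
#     out.token_attributions.normalized     # preferred
#     out.normalized_token_attributions     # same value, but emits a DeprecationWarning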


@dataclass
class PipelineExplainerForBoundingBoxOutput(PipelineExplainerOutput):
    """
    Output class for BasePipelineExplainer.__call__ if `init_image=None` and `explanation_2d_bounding_box is not None`

    Args:
        image (`Image` or `torch.Tensor`)
            The denoised PIL output image or torch.Tensor of shape `(height, width, num_channels)`.
        nsfw_content_detected (`Optional[bool]`)
            A flag denoting whether the generated image likely represents "not-safe-for-work"
            (nsfw) content.
        all_images_during_generation (`Optional[Union[GeneratedImages, List[torch.Tensor]]]`)
            A GeneratedImages object to visualize all the images generated during diffusion,
            or a list of tensors of those images.
        token_attributions (`Optional[TokenAttributions]`)
            A TokenAttributions object containing a list of `(token, token_attribution)` tuples.
        explanation_2d_bounding_box (`Tuple[Tuple[int, int], Tuple[int, int]]`)
            Tuple with the coordinates of the bounding box for which the attributions were calculated,
            as (upper left corner, bottom right corner). Example: `((0, 0), (300, 300))`
    """
    explanation_2d_bounding_box: Optional[Tuple[Tuple[int, int], Tuple[int, int]]] = None  # (upper left corner, bottom right corner)
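
# The bounding box restricts attribution to a sub-region of the generated image.
# A minimal sketch (`img` is hypothetical; the box covers the top-left 300x300 pixels):
#
#     out = PipelineExplainerForBoundingBoxOutput(
#         image=img,
#         explanation_2d_bounding_box=((0, 0), (300, 300)),
#     )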


@dataclass
class PipelineImg2ImgExplainerOutput(PipelineExplainerOutput):
    """
    Output class for BasePipelineExplainer.__call__ if `init_image is not None` and `explanation_2d_bounding_box=None`

    Args:
        image (`Image` or `torch.Tensor`)
            The denoised PIL output image or torch.Tensor of shape `(height, width, num_channels)`.
        nsfw_content_detected (`Optional[bool]`)
            A flag denoting whether the generated image likely represents "not-safe-for-work"
            (nsfw) content.
        all_images_during_generation (`Optional[Union[GeneratedImages, List[torch.Tensor]]]`)
            A GeneratedImages object to visualize all the images generated during diffusion,
            or a list of tensors of those images.
        token_attributions (`Optional[TokenAttributions]`)
            A TokenAttributions object containing a list of `(token, token_attribution)` tuples.
        pixel_attributions (`Optional[PixelAttributions]`)
            A PixelAttributions numpy array of shape `(height, width)` with an attribution score per pixel
            in the input image.
        input_saliency_map (`Optional[SaliencyMap]`)
            A SaliencyMap object to visualize the pixel attributions of the input image.
    """
    pixel_attributions: Optional[PixelAttributions] = None

    def __getattr__(self, attr):
        if attr == 'normalized_pixel_attributions':
            warnings.warn(
                f"`normalized_pixel_attributions` is deprecated as an attribute of `{self.__class__.__name__}` "
                f"and will be removed in a future version. Consider using `output.pixel_attributions.normalized` instead",
                DeprecationWarning, stacklevel=2
            )
            return self.pixel_attributions.normalized
        elif attr == 'input_saliency_map':
            return self.pixel_attributions.saliency_map
        return super().__getattr__(attr)
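
# `input_saliency_map` is not a stored field: `__getattr__` forwards it to
# `pixel_attributions.saliency_map`. A minimal sketch (`img` and `pixel_attributions`
# are hypothetical values, shown for illustration only):
#
#     out = PipelineImg2ImgExplainerOutput(image=img, pixel_attributions=pixel_attributions)
#     out.input_saliency_map    # same object as out.pixel_attributions.saliency_map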


@dataclass
class PipelineImg2ImgExplainerForBoundingBoxOutputOutput(PipelineExplainerForBoundingBoxOutput, PipelineImg2ImgExplainerOutput):
    """
    Output class for BasePipelineExplainer.__call__ if `init_image is not None` and `explanation_2d_bounding_box is not None`

    Args:
        image (`Image` or `torch.Tensor`)
            The denoised PIL output image or torch.Tensor of shape `(height, width, num_channels)`.
        nsfw_content_detected (`Optional[bool]`)
            A flag denoting whether the generated image likely represents "not-safe-for-work"
            (nsfw) content.
        all_images_during_generation (`Optional[Union[GeneratedImages, List[torch.Tensor]]]`)
            A GeneratedImages object to visualize all the images generated during diffusion,
            or a list of tensors of those images.
        token_attributions (`Optional[TokenAttributions]`)
            A TokenAttributions object containing a list of `(token, token_attribution)` tuples.
        pixel_attributions (`Optional[PixelAttributions]`)
            A PixelAttributions numpy array of shape `(height, width)` with an attribution score per pixel
            in the input image.
        input_saliency_map (`Optional[SaliencyMap]`)
            A SaliencyMap object to visualize the pixel attributions of the input image.
        explanation_2d_bounding_box (`Tuple[Tuple[int, int], Tuple[int, int]]`)
            Tuple with the coordinates of the bounding box for which the attributions were calculated,
            as (upper left corner, bottom right corner). Example: `((0, 0), (300, 300))`
    """
    pass
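
# This class combines the img2img fields (pixel attributions / saliency map) with the
# bounding box field through multiple inheritance; it adds no fields of its own.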


class ExplicitEnum(str, Enum):
    """
    Enum with a more explicit error message for missing values.
    """

    @classmethod
    def _missing_(cls, value):
        raise ValueError(
            f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
        )
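
# `_missing_` is the hook Python's `enum` invokes when a lookup value matches no member,
# so invalid values fail with a message that lists the valid options:
#
#     AttributionAlgorithm("grad_x_input")   # -> AttributionAlgorithm.GRAD_X_INPUT
#     AttributionAlgorithm("bad_value")      # -> ValueError listing the valid values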


class AttributionAlgorithm(ExplicitEnum):
    """
    Possible values for the `tokens_attribution_method` and `pixels_attribution_method` arguments in `AttributionMethods`
    """

    GRAD_X_INPUT = "grad_x_input"
    MAX_GRAD = "max_grad"
    MEAN_GRAD = "mean_grad"
    MIN_GRAD = "min_grad"


@dataclass
class AttributionMethods:
    tokens_attribution_method: Union[str, AttributionAlgorithm] = AttributionAlgorithm.GRAD_X_INPUT
    pixels_attribution_method: Optional[Union[str, AttributionAlgorithm]] = AttributionAlgorithm.MAX_GRAD
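
# A minimal configuration sketch: both fields accept either an enum member or its
# string value (since ExplicitEnum subclasses str, the two compare equal):
#
#     methods = AttributionMethods(
#         tokens_attribution_method="grad_x_input",
#         pixels_attribution_method=AttributionAlgorithm.MEAN_GRAD,
#     )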