-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathpixel_attributions.py
28 lines (19 loc) · 1005 Bytes
/
pixel_attributions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from typing import Any, Union
import numpy as np
from diffusers_interpret.saliency_map import SaliencyMap
class PixelAttributions(np.ndarray):
    """Pixel-level attribution scores stored as an ``ndarray`` subclass.

    NumPy operations see the attribution values themselves; in addition each
    instance carries:

    * ``pixel_attributions`` -- the raw attribution array;
    * ``normalized`` -- ``100 * attributions / attributions.sum()``, i.e. each
      entry as a percentage of the total;
    * ``saliency_map`` -- the ``SaliencyMap`` passed at construction.

    String keys (``obj["normalized"]``) read and write these attributes; any
    other key indexes the raw attribution array.
    """

    def __new__(
        cls, pixel_attributions: np.ndarray, saliency_map: "SaliencyMap"
    ) -> "PixelAttributions":
        """Wrap ``pixel_attributions`` (ndarray or array-like) with its metadata.

        For an ``ndarray`` input no copy is made: the instance is a view of the
        input and ``pixel_attributions`` is the input object itself.
        """
        # Convert once so the view, the stored raw array and the normalization
        # all use the same ndarray (a list input used to fail on ``.sum()``).
        attributions = np.asarray(pixel_attributions)
        obj = attributions.view(cls)
        obj.pixel_attributions = attributions
        obj.saliency_map = saliency_map
        obj.normalized = cls._normalize(attributions)
        return obj

    def __array_finalize__(self, obj: Any) -> None:
        """Initialise attributes on instances NumPy creates without ``__new__``.

        Views, reshapes, copies and ufunc results of a ``PixelAttributions``
        come through here; without it they lack the attributes that
        ``__repr__`` and ``__getitem__`` rely on.
        """
        # Refer to this array's *own* data rather than the parent's, so derived
        # results (e.g. ``attrs * 2``) report their own values.
        self.pixel_attributions = self.view(np.ndarray)
        self.saliency_map = getattr(obj, "saliency_map", None)
        # NumPy may call this hook before a ufunc has written its output, so a
        # normalization cannot be computed reliably here.
        self.normalized = None

    @staticmethod
    def _normalize(values: Any) -> np.ndarray:
        """Return ``values`` rescaled to percentages of their total.

        A zero total yields NaN/inf entries (NumPy emits a RuntimeWarning).
        """
        values = np.asarray(values)
        return 100 * (values / values.sum())

    def __getitem__(self, item: Union[str, int]) -> Any:
        """Return attribute ``item`` for strings, else index the raw attributions."""
        if isinstance(item, str):
            return getattr(self, item)
        return self.pixel_attributions[item]

    def __setitem__(self, key: Union[str, int], value: Any) -> None:
        """Set attribute ``key`` for strings, else write into the raw attributions.

        Non-string keys (ints, slices, tuples, ...) mirror ``__getitem__``;
        after such a write ``normalized`` is recomputed so it keeps matching
        the data.
        """
        if isinstance(key, str):
            setattr(self, key, value)
            return
        self.pixel_attributions[key] = value
        self.normalized = self._normalize(self.pixel_attributions)

    def __repr__(self) -> str:
        """Represent the instance by the repr of its raw attribution array."""
        return repr(self.pixel_attributions)