Merged
33 commits
d600559
Fix the Inception5h model's download link
ProGamerGov Apr 26, 2021
8612fee
Fix black error
ProGamerGov Apr 26, 2021
d31d60d
Ensure history has no gradient
ProGamerGov Apr 26, 2021
bd3c7fe
Remove .cpu() to improve optimization speed
ProGamerGov Apr 26, 2021
bbebf22
Temporarily disable nightly build tests affected by pytorch/pytorch#5…
ProGamerGov May 3, 2021
0fb9eb5
Fix flake8 error
ProGamerGov May 3, 2021
7c453e0
Remove ImageTensor test skips & add new tests
ProGamerGov May 6, 2021
5d0d143
Fix ImageTensor __new__ list test
ProGamerGov May 6, 2021
50455bf
Fix NaturalImage device bug
ProGamerGov May 12, 2021
c9ece95
Set decorrelate_init default back to True
ProGamerGov May 12, 2021
4bd8e5b
Check for presence of Pillow / PIL library in ImageTensor applicable …
ProGamerGov May 12, 2021
71a370d
Update Conda installation script to latest version
ProGamerGov May 17, 2021
e7fcdc0
Add SkipLayer to models __init__
ProGamerGov May 19, 2021
015890e
Make SkipLayer work if there are any additional init or forward argum…
ProGamerGov May 19, 2021
1bcac93
Minor correction to optimize's loss summarizer setup
ProGamerGov May 20, 2021
811a269
Fix _rand_select bug
ProGamerGov May 21, 2021
1f2d421
Increase number of steps in optimization test
ProGamerGov May 22, 2021
8227b79
Fix FFTImage support for images with odd width values
ProGamerGov May 22, 2021
f10fa86
Fix test_rfft2d_freqs & add more SkipLayer tests
ProGamerGov May 22, 2021
9157524
Remove duplicate imports
ProGamerGov May 23, 2021
6359321
Make it possible to load RGBA images with ImageTensor
ProGamerGov May 23, 2021
e4a1310
Remove unused line: 'h, w = self.size'
ProGamerGov May 24, 2021
9ca0f9b
Change NumPy rfft2d_freqs to match PyTorch version
ProGamerGov May 26, 2021
681e3ee
Fix PCA ChannelReducer test
ProGamerGov May 26, 2021
8aaf1a2
Fix failing nodejs
ProGamerGov May 26, 2021
27145c7
Resolve the ToRGB device issue with NaturalImage
ProGamerGov May 26, 2021
15a7db6
Fix no color decorrelation test
ProGamerGov May 26, 2021
02fc2d7
Add ToDos and better SkipLayer docs
ProGamerGov Jun 1, 2021
7f078f9
Fix lint and FFTImage init device
ProGamerGov Jun 1, 2021
d9b4620
Fix NaturalImage device issues
ProGamerGov Jun 1, 2021
5aaece3
Improve NaturalImage fix
ProGamerGov Jun 1, 2021
c72adc5
Remove redundant NaturalImage device fix
ProGamerGov Jun 2, 2021
1eae61c
Improve SkipLayer documentation
ProGamerGov Jun 7, 2021
8 changes: 4 additions & 4 deletions captum/attr/_core/feature_ablation.py
@@ -62,7 +62,7 @@ def attribute(
additional_forward_args: Any = None,
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
perturbations_per_eval: int = 1,
**kwargs: Any
**kwargs: Any,
) -> TensorOrTupleOfTensorsGeneric:
r"""
Args:
@@ -321,7 +321,7 @@ def attribute(
baselines,
feature_mask,
perturbations_per_eval,
**kwargs
**kwargs,
):
# modified_eval dimensions: 1D tensor with length
# equal to #num_examples * #features in batch
@@ -373,7 +373,7 @@ def _ablation_generator(
baselines,
input_mask,
perturbations_per_eval,
**kwargs
**kwargs,
):
"""
This method is a generator which yields each perturbation to be evaluated
@@ -458,7 +458,7 @@ def _ablation_generator(
baseline,
num_features_processed,
num_features_processed + current_num_ablated_features,
**extra_args
**extra_args,
)

# current_features[i] has dimension
4 changes: 2 additions & 2 deletions captum/metrics/_core/infidelity.py
@@ -68,7 +68,7 @@ def sub_infidelity_perturb_func_decorator(pertub_func: Callable) -> Callable:
def default_perturb_func(
inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None
):
r""""""
r""" """
inputs_perturbed = (
pertub_func(inputs, baselines)
if baselines is not None
@@ -380,7 +380,7 @@ def _generate_perturbations(
"""

def call_perturb_func():
r""""""
r""" """
baselines_pert = None
inputs_pert: Union[Tensor, Tuple[Tensor, ...]]
if len(inputs_expanded) == 1:
1 change: 0 additions & 1 deletion captum/optim/__init__.py
@@ -12,7 +12,6 @@
show,
weights_to_heatmap_2d,
)
from captum.optim._utils.reducer import ChannelReducer, posneg # noqa: F401

__all__ = [
"InputOptimization",
5 changes: 3 additions & 2 deletions captum/optim/_core/optimization.py
@@ -111,7 +111,7 @@ def optimize(
self,
stop_criteria: Optional[StopCriteria] = None,
optimizer: Optional[optim.Optimizer] = None,
loss_summarize_fn: Optional[Callable] = default_loss_summarize,
loss_summarize_fn: Optional[Callable] = None,
lr: float = 0.025,
) -> torch.Tensor:
r"""Optimize input based on loss function and objectives.
@@ -131,14 +131,15 @@
stop_criteria = stop_criteria or n_steps(512)
optimizer = optimizer or optim.Adam(self.parameters(), lr=lr)
assert isinstance(optimizer, optim.Optimizer)
loss_summarize_fn = loss_summarize_fn or default_loss_summarize

history = []
step = 0
try:
while stop_criteria(step, self, history, optimizer):
optimizer.zero_grad()
loss_value = loss_summarize_fn(self.loss())
history.append(loss_value)
history.append(loss_value.clone().detach())
loss_value.backward()
optimizer.step()
step += 1
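This change defaults loss_summarize_fn to None and resolves it to default_loss_summarize inside optimize(), and each recorded loss is stored as loss_value.clone().detach() so the history list no longer keeps the autograd graph alive. A minimal sketch (not Captum code) of why the detach matters:

```python
import torch

p = torch.nn.Parameter(torch.ones(3))
loss = (p ** 2).sum()

kept = loss                            # still attached to the autograd graph
kept_detached = loss.clone().detach()  # a plain value; the graph can be freed

assert kept.requires_grad
assert not kept_detached.requires_grad  # safe to accumulate in a history list
```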
21 changes: 11 additions & 10 deletions captum/optim/_param/image/images.py
@@ -35,13 +35,13 @@ def __new__(
return super().__new__(cls, x, *args, **kwargs)

@classmethod
def open(cls, path: str, scale: float = 255.0) -> "ImageTensor":
def open(cls, path: str, scale: float = 255.0, mode: str = "RGB") -> "ImageTensor":
if path.startswith("https://") or path.startswith("http://"):
response = requests.get(path, stream=True)
img = Image.open(response.raw)
else:
img = Image.open(path)
img_np = np.array(img.convert("RGB")).astype(np.float32)
img_np = np.array(img.convert(mode)).astype(np.float32)
return cls(img_np.transpose(2, 0, 1) / scale)

def __repr__(self) -> str:
@@ -116,7 +116,6 @@ def __init__(
)
scale = scale * ((self.size[0] * self.size[1]) ** (1 / 2))
spectrum_scale = scale[None, :, :, None]
self.register_buffer("spectrum_scale", spectrum_scale)

if init is None:
coeffs_shape = (
@@ -131,16 +130,16 @@
) # names=["C", "H_f", "W_f", "complex"]
fourier_coeffs = random_coeffs / 50
else:
spectrum_scale = spectrum_scale.to(init.device)
fourier_coeffs = self.torch_rfft(init) / spectrum_scale

self.register_buffer("spectrum_scale", spectrum_scale)
self.fourier_coeffs = nn.Parameter(fourier_coeffs)

def rfft2d_freqs(self, height: int, width: int) -> torch.Tensor:
"""Computes 2D spectrum frequencies."""
fy = self.torch_fftfreq(height)[:, None]
# on odd input dimensions we need to keep one additional frequency
wadd = 2 if width % 2 == 1 else 1
fx = self.torch_fftfreq(width)[: width // 2 + wadd]
fx = self.torch_fftfreq(width)[: width // 2 + 1]
return torch.sqrt((fx * fx) + (fy * fy))

def get_fft_funcs(self) -> Tuple[Callable, Callable, Callable]:
@@ -181,7 +180,6 @@ def torch_fftfreq(v: int, d: float = 1.0) -> torch.Tensor:
return torch_rfft, torch_irfft, torch_fftfreq

def forward(self) -> torch.Tensor:
h, w = self.size
scaled_spectrum = self.fourier_coeffs * self.spectrum_scale
output = self.torch_irfft(scaled_spectrum)
return output.refine_names("B", "C", "H", "W")
@@ -212,6 +210,10 @@ def forward(self) -> torch.Tensor:


class LaplacianImage(ImageParameterization):
"""
TODO: Fix division by 6 in setup_input when init is not None.
"""

def __init__(
self,
size: Tuple[int, int] = None,
@@ -418,7 +420,7 @@ class NaturalImage(ImageParameterization):

def __init__(
self,
size: Tuple[int, int] = [224, 224],
size: Tuple[int, int] = (224, 224),
channels: int = 3,
batch: int = 1,
init: Optional[torch.Tensor] = None,
@@ -431,8 +433,7 @@ def __init__(
self.decorrelate = decorrelation_module
if init is not None:
assert init.dim() == 3 or init.dim() == 4
if decorrelate_init:
assert self.decorrelate is not None
if decorrelate_init and self.decorrelate is not None:
init = (
init.refine_names("B", "C", "H", "W")
if init.dim() == 4
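The rfft2d_freqs fix above relies on a real FFT keeping width // 2 + 1 frequency bins for both even and odd widths, so the extra frequency previously added for odd sizes is unnecessary. A small sketch of that fact, assuming PyTorch 1.8+ with the torch.fft module available:

```python
import torch

# A real FFT over a dimension of length `width` retains width // 2 + 1 bins,
# regardless of whether `width` is even or odd.
for width in (8, 9):
    n_freqs = torch.fft.rfftfreq(width).numel()
    assert n_freqs == width // 2 + 1
```

The new mode argument on ImageTensor.open likewise allows loading four-channel images, e.g. ImageTensor.open(path, mode="RGBA").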
55 changes: 4 additions & 51 deletions captum/optim/_param/image/transforms.py
@@ -107,9 +107,9 @@ def forward(self, x: torch.Tensor, inverse: bool = False) -> torch.Tensor:
h, w = x.size("H"), x.size("W")
flat = x.flatten(("H", "W"), "spatials")
if inverse:
correct = torch.inverse(self.transform) @ flat
correct = torch.inverse(self.transform.to(x.device)) @ flat
else:
correct = self.transform @ flat
correct = self.transform.to(x.device) @ flat
chw = correct.unflatten("spatials", (("H", h), ("W", w)))

if x.dim() == 3:
@@ -217,9 +217,9 @@ def _rand_select(
transform_values: NumSeqOrTensorType,
) -> Union[int, float, torch.Tensor]:
"""
Randomly return a value from the provided tuple or list
Randomly return a single value from the provided tuple, list, or tensor.
"""
n = torch.randint(low=0, high=len(transform_values) - 1, size=[1]).item()
n = torch.randint(low=0, high=len(transform_values), size=[1]).item()
return transform_values[n]
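The corrected bound works because the high argument of torch.randint is exclusive, so the previous len(transform_values) - 1 could never select the last element. A quick sketch (not Captum code) of the fixed sampling:

```python
import torch

values = [0.5, 1.0, 1.5]

# `high` is exclusive, so idx can be 0, 1, or 2 (including the last index).
idx = torch.randint(low=0, high=len(values), size=[1]).item()
sample = values[idx]
```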


@@ -503,53 +503,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
)


class AlphaChannelLoss(nn.Module):
"""
TODO: Fix AlphaChannelLoss
Transform for calculating alpha channel loss, without altering the input tensor.
Loss values are calculated in such a way that opaque and transparent regions of
the tensor are automatically balanced.
See: https://distill.pub/2018/differentiable-parameterizations/
Mordvintsev, et al., "Differentiable Image Parameterizations", Distill, 2018.
Args:
scale (float, sequence): Tuple of rescaling values to randomly select from.
crop_size (int, sequence, int, optional): The desired cropped output size
for secondary alpha channel loss.
background (tensor, optional): An NCHW image tensor to be used as the
alpha channel's background.
"""

def __init__(
self,
scale: NumSeqOrTensorType,
crop_size: Optional[Tuple[int, int]] = None,
background: Optional[torch.Tensor] = None,
) -> None:
raise NotImplementedError # We are not ready for this
super().__init__()
self.random_scale = RandomScale(scale=scale)
self.crop_size = crop_size
self.random_crop = RandomCrop(crop_size)
self.blend_alpha = BlendAlpha(background=background)
self.loss = 0

def forward(self, x: torch.Tensor) -> torch.Tensor:
assert x.dim() == 4 # Should be of shape (batch, channel, height, width)
assert x.size(1) == 4 # Channel dim should be rgba

x_shifted = torch.cat([self.blend_alpha(x.clone()), x.clone()[:, 3:]], 1)

x_shifted = self.random_scale(x_shifted)
x_shifted_crop = self.random_crop(x_shifted)

self.loss = (1.0 - x_shifted[:, 3:].mean()) + (
(1.0 - x_shifted_crop[:, 3:].mean()) * 0.5
)
return x


__all__ = [
"BlendAlpha",
"IgnoreAlpha",
2 changes: 2 additions & 0 deletions captum/optim/models/__init__.py
@@ -1,5 +1,6 @@
from ._common import ( # noqa: F401
RedirectedReluLayer,
SkipLayer,
collect_activations,
get_model_layers,
replace_layers,
@@ -10,6 +11,7 @@

__all__ = [
"RedirectedReluLayer",
"SkipLayer",
"collect_activations",
"get_model_layers",
"replace_layers",
31 changes: 29 additions & 2 deletions captum/optim/models/_common.py
@@ -199,10 +199,37 @@ def collect_activations(

class SkipLayer(torch.nn.Module):
"""
This layer is made to take the place of nonlinear activation layers like ReLU.
This layer is made to take the place of any layer that needs to be skipped over
during the forward pass. Use cases include removing nonlinear activation layers
like ReLU for circuits research.

This layer works almost exactly the same way that nn.Identity does, except it also
ignores any additional arguments passed to the forward function. Any layer replaced
by SkipLayer must have the same input and output shapes.

See nn.Identity for more details:
https://pytorch.org/docs/stable/generated/torch.nn.Identity.html

Args:
args (Any): Any argument. Arguments will be safely ignored.
kwargs (Any): Any keyword argument. Arguments will be safely ignored.
"""

def forward(self, x: torch.Tensor) -> torch.Tensor:
def __init__(self, *args, **kwargs) -> None:
Contributor (reviewer):
Why are the args and kwargs passed if they are not used? Are there any real scenarios where we need them? Is this because we replace ReLU with SkipLayer?
I could imagine, for example, that x is a tuple of tensors and we would need to return that tuple.

ProGamerGov (Contributor, author):
I made the change because users may have models using activ(inplace=False) or activ(False), where activ = torch.nn.ReLU. This is the same way that torch.nn.Identity works: https://pytorch.org/docs/stable/generated/torch.nn.Identity.html
I'll add the type hint for tuples of tensors.

ProGamerGov (Contributor, author), Jun 1, 2021:
I've improved the documentation, added the type hints, and also provided a link to the nn.Identity class!
super().__init__()

def forward(
self, x: Union[torch.Tensor, Tuple[torch.Tensor]], *args, **kwargs
) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
"""
Args:
x (torch.Tensor or tuple of torch.Tensor): The input tensor or tensors.
args (Any): Any argument. Arguments will be safely ignored.
kwargs (Any): Any keyword argument. Arguments will be safely ignored.
Returns:
x (torch.Tensor or tuple of torch.Tensor): The unmodified input tensor or
tensors.
"""
return x
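A minimal usage sketch for the updated SkipLayer, assuming only the captum.optim.models export added in this PR; swapping a ReLU for a SkipLayer leaves activations unchanged while safely ignoring any constructor or forward arguments:

```python
import torch
import torch.nn as nn
from captum.optim.models import SkipLayer  # export added by this PR

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(inplace=True), nn.Conv2d(8, 8, 3))
model[1] = SkipLayer(inplace=True)  # extra init arguments are accepted and ignored

x = torch.randn(1, 3, 16, 16)
out = model(x)  # the skipped layer passes its input through unchanged
```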


2 changes: 1 addition & 1 deletion captum/optim/models/_image/inception_v1.py
@@ -7,7 +7,7 @@

GS_SAVED_WEIGHTS_URL = (
"https://github.com/pytorch/captum/raw/"
+ "optim-wip/captum/optim/_models/inception5h.pth"
+ "optim-wip/captum/optim/models/_image/inception5h.pth"
)


9 changes: 4 additions & 5 deletions scripts/install_via_conda.sh
@@ -34,15 +34,14 @@ else
fi

# install other deps
conda install -y numpy sphinx pytest flake8 ipywidgets ipython
conda install -y -c conda-forge black matplotlib pytest-cov sphinx-autodoc-typehints mypy flask isort
conda install -y numpy sphinx pytest flake8 ipywidgets ipython scikit-learn
conda install -y -c conda-forge black matplotlib pytest-cov sphinx-autodoc-typehints mypy flask isort flask-compress

# install node/yarn for insights build
conda install -y -c conda-forge yarn
# nodejs should be last, otherwise other conda packages will downgrade node
conda update -y --no-channel-priority -c conda-forge nodejs
conda install -y --no-channel-priority -c conda-forge nodejs=14


# build insights and install captum
# TODO: remove CI=false when we want React warnings treated as errors
CI=false BUILD_INSIGHTS=1 python setup.py develop
BUILD_INSIGHTS=1 python setup.py develop
2 changes: 1 addition & 1 deletion tests/optim/core/test_optimization.py
@@ -17,7 +17,7 @@ def test_input_optimization(self) -> None:
model = BasicModel_ConvNet_Optim()
loss_fn = opt.loss.ChannelActivation(model.layer, 0)
obj = opt.InputOptimization(model, loss_function=loss_fn)
n_steps = 5
n_steps = 25
history = obj.optimize(opt.optimization.n_steps(n_steps, show_progress=False))
self.assertTrue(history[0] > history[-1])
self.assertTrue(len(history) == n_steps)
20 changes: 1 addition & 19 deletions tests/optim/helpers/numpy_image.py
@@ -3,17 +3,6 @@
import numpy as np


def setup_batch(x: np.ndarray, batch: int = 1, dim: int = 3) -> np.ndarray:
assert batch > 0
x = x[None, :] if x.ndim == dim and batch == 1 else x
x = (
np.stack([np.copy(x) for b in range(batch)])
if x.ndim == dim and batch > 1
else x
)
return x


class FFTImage:
"""Parameterize an image using inverse real 2D FFT"""

@@ -62,17 +51,10 @@ def __init__(
def rfft2d_freqs(height: int, width: int) -> np.ndarray:
"""Computes 2D spectrum frequencies."""
fy = np.fft.fftfreq(height)[:, None]
# on odd input dimensions we need to keep one additional frequency
wadd = 2 if width % 2 == 1 else 1
fx = np.fft.fftfreq(width)[: width // 2 + wadd]
fx = np.fft.fftfreq(width)[: width // 2 + 1]
return np.sqrt((fx * fx) + (fy * fy))

def set_image(self, correlated_image: np.ndarray) -> None:
coeffs = np.fft.rfftn(correlated_image, s=self.size).view("(2,)float")
self.fourier_coeffs = coeffs / self.spectrum_scale

def forward(self) -> np.ndarray:
h, w = self.size
scaled_spectrum = self.fourier_coeffs * self.spectrum_scale
scaled_spectrum = scaled_spectrum.astype(complex)
output = np.fft.irfftn(scaled_spectrum, s=self.size)
12 changes: 12 additions & 0 deletions tests/optim/models/test_models_common.py
@@ -270,6 +270,18 @@ def test_skip_layer(self) -> None:
output_tensor = layer(x)
assertTensorAlmostEqual(self, x, output_tensor, 0)

def test_skip_layer_ignore_init_variables(self) -> None:
layer = model_utils.SkipLayer(0, inplace=True)
x = torch.randn(1, 3, 4, 4)
output_tensor = layer(x)
assertTensorAlmostEqual(self, x, output_tensor, 0)

def test_skip_layer_ignore_forward_variables(self) -> None:
layer = model_utils.SkipLayer()
x = torch.randn(1, 3, 4, 4)
output_tensor = layer(x, 1, inverse=True)
assertTensorAlmostEqual(self, x, output_tensor, 0)


class TestSkipLayersFunction(BaseTest):
def test_skip_layers(self) -> None: