# Preparing Models For Captum's Optim Module

While most models work out of the box with the Optim module, some models may require a few minor changes for full compatibility. This tutorial demonstrates how to make the suggested and required changes to a model before using it with the Optim module.

In [None]:
%load_ext autoreload
%autoreload 2

import captum.optim as opt
import torch
import torch.nn.functional as F

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Model Layer Changes

The Optim module's layer-related functions and optimization systems rely on layers being defined as `nn.Module` classes rather than functional layers. Specifically, Optim's loss optimization and activation collection rely on PyTorch's hook system via [`register_forward_hook`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=register_forward_hook#torch.nn.Module.register_forward_hook), and functional layers do not support hooks.
Other functions, like `replace_layers`, can only detect `nn.Module` objects inside models.


For the purposes of this tutorial, our main toy model does not use any functional layers. If you want to use your own model, verify that all applicable functional layers in it have been changed to their `nn.Module` equivalents.

* A list of all of PyTorch's `torch.nn.functional` layers can be found [here](https://pytorch.org/docs/stable/nn.functional.html), and each entry links to its `nn.Module` equivalent.

* The most common change you are likely to encounter is converting functional [`F.relu`](https://pytorch.org/docs/stable/generated/torch.nn.functional.relu.html#torch.nn.functional.relu) layers to [`nn.ReLU`](https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html).
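
As a minimal sketch of this conversion, the functional call moves out of `forward` and into a submodule defined in `__init__`. The `FunctionalBlock` and `ModuleBlock` classes are hypothetical names used only for illustration and are not part of Captum.

```python
import torch
import torch.nn.functional as F


class FunctionalBlock(torch.nn.Module):
    """Hypothetical block that applies ReLU through the functional API."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 4, kernel_size=3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.relu(self.conv(x))  # No module object exists for this ReLU


class ModuleBlock(torch.nn.Module):
    """The same block with ReLU defined as an nn.Module."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 4, kernel_size=3)
        self.relu = torch.nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.relu(self.conv(x))


# Only layers registered as submodules are visible to module-based tooling
functional_names = [name for name, _ in FunctionalBlock().named_children()]
module_names = [name for name, _ in ModuleBlock().named_children()]
print(functional_names)  # ['conv']
print(module_names)  # ['conv', 'relu']
```

Both blocks compute the same function, but only the second one exposes its ReLU as an object that hooks can attach to.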

## Tutorial Setup

Below we define a simple toy model and a functional version of the toy model for usage in our examples.

In [None]:
class ToyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.basic_module = torch.nn.Sequential(
            torch.nn.Conv2d(3, 4, kernel_size=3, stride=2),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.conv = torch.nn.Conv2d(4, 4, kernel_size=3, stride=2)
        self.bn = torch.nn.BatchNorm2d(4)
        self.relu = torch.nn.ReLU()
        self.pooling = torch.nn.AdaptiveAvgPool2d((2, 2))
        self.linear = torch.nn.Linear(16, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.basic_module(x)
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.pooling(x)
        x = x.flatten()
        x = self.linear(x)
        return x


class ToyModelFunctional(torch.nn.Module):
    """Functional layer only version of our toy model"""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Constant weights are created from the input so that device and dtype match.
        # F.conv2d infers the kernel size from the weight shape, so it takes no
        # kernel_size argument.
        x = F.conv2d(x, weight=x.new_ones([4, 3, 3, 3]), stride=2)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=3, stride=2)

        # The second convolution receives 4 input channels from the first one
        x = F.conv2d(x, weight=x.new_ones([4, 4, 3, 3]), stride=2)
        x = F.batch_norm(x, running_mean=x.new_ones([4]), running_var=x.new_ones([4]))
        x = F.relu(x)
        x = F.adaptive_avg_pool2d(x, (2, 2))
        x = x.flatten()
        x = F.linear(x, weight=x.new_ones([4, 16]))
        return x

## The Basics: Targetable Layers

The Optim module's `opt.models.collect_activations` function and loss objectives (`opt.loss.<LossObjective>`) rely on forward hooks from PyTorch's hook system. This means that functional layers cannot be used as optimization targets, and activations cannot be collected from them.
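
To illustrate the underlying mechanism, here is a minimal sketch of capturing a layer's output with a forward hook in plain PyTorch. It is not Captum's implementation; the `captured` dictionary and `save_output` function are hypothetical names chosen for this example.

```python
import torch

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 4, kernel_size=3, stride=2),
    torch.nn.ReLU(),
)
captured = {}


def save_output(module, inputs, output):
    # Forward hooks receive the module, its inputs, and its output
    captured[module] = output.detach()


# Hooks attach to nn.Module objects, which is why functional calls cannot be targeted
handle = model[1].register_forward_hook(save_output)
with torch.no_grad():
    model(torch.zeros(1, 3, 8, 8))
handle.remove()  # Remove the hook once it is no longer needed

relu_shape = tuple(captured[model[1]].shape)
print(relu_shape)  # (1, 4, 3, 3)
```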

Models can easily be checked for compatible layers via the `opt.models.get_model_layers` function as we'll see below.

In [None]:
# Functional version of the toy model with no nn.Module layers
toy_model_functional = ToyModelFunctional().eval().to(device)

# Get hookable layers
possible_targets = opt.models.get_model_layers(toy_model_functional)

print("Possible targets:", possible_targets)

Possible targets: []


As you can see, no layers capable of being hooked were found in our functional layer model.

Below we use the `opt.models.get_model_layers` function to see a list of all the hookable layers in our non-functional model that we can use as targets.

In [None]:
# Toy model with only nn.Module layers
target_model = ToyModel().eval().to(device)

# Get hookable layers
possible_targets = opt.models.get_model_layers(target_model)

# Display hookable layers
print("Possible targets:")
for t in possible_targets:
    print("  target_model." + t)

Possible targets:
  target_model.basic_module
  target_model.basic_module[0]
  target_model.basic_module[1]
  target_model.basic_module[2]
  target_model.conv
  target_model.bn
  target_model.relu
  target_model.pooling
  target_model.linear


We can then easily use any of the targets found above for optimization and activation collection, as we show below.

In [None]:
target_model = ToyModel().eval().to(device)

# Set layer target
target_layer = target_model.conv

# Collect activations from target
activations_dict = opt.models.collect_activations(
    model=target_model, targets=target_layer
)

# Collect target from activations dict
activations = activations_dict[target_layer]

# Display activation shape
print("Output shape of the {} layer activations:".format(type(target_layer)))
print("  {} \n".format(activations.shape))

# We can also use the target for loss objectives
loss_fn = opt.loss.LayerActivation(target=target_layer)

# Print loss objective
print("Loss objective:", loss_fn)
print("  target:", loss_fn.target)

Output shape of the <class 'torch.nn.modules.conv.Conv2d'> layer activations:
  torch.Size([1, 4, 27, 27]) 

Loss objective: LayerActivation []
  target: Conv2d(4, 4, kernel_size=(3, 3), stride=(2, 2))
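
The spatial size of 27 printed above can be checked by hand. The sketch below assumes the activations were collected with a 1 x 3 x 224 x 224 input, which is an assumption since the input is not shown in the cell above, and applies the standard output size formula for convolution and pooling layers with no padding and a dilation of 1. The `conv_output_size` helper is a hypothetical name.

```python
def conv_output_size(size: int, kernel_size: int, stride: int, padding: int = 0) -> int:
    """Spatial output size of a convolution or pooling layer with dilation 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


size = 224  # Assumed input height and width
size = conv_output_size(size, kernel_size=3, stride=2)  # basic_module[0] -> 111
size = conv_output_size(size, kernel_size=3, stride=2)  # basic_module[2] -> 55
size = conv_output_size(size, kernel_size=3, stride=2)  # conv -> 27
print(size)  # 27
```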


## Visualization: Redirected ReLU

In some cases, the target of interest may not be activated at all by the initial random input. When this happens, the zero derivative stops the gradient from flowing backwards, so optimization never moves towards a meaningful visualization. To solve this problem, we can replace the ReLU layers in a model with a special version of ReLU called `RedirectedReLU`. The `RedirectedReLU` layer temporarily allows the gradient to flow in these zero-gradient situations.
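
The zero-gradient problem is easy to reproduce in plain PyTorch. The sketch below also includes a deliberately simplified redirected ReLU to show the general idea. `SketchRedirectedReLU` is a hypothetical class written for this example and should not be read as Captum's implementation.

```python
import torch


class SketchRedirectedReLU(torch.autograd.Function):
    """Illustrative only: ReLU that passes gradients through when all would be zero."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        relu_grad = grad_output * (x > 0).to(grad_output.dtype)
        if not (relu_grad != 0).any():
            # Every unit is inactive, so let the incoming gradient through
            return grad_output.clone()
        return relu_grad


# An input that leaves every unit inactive
x_plain = torch.full((4,), -1.0, requires_grad=True)
torch.relu(x_plain).sum().backward()

x_redirected = torch.full((4,), -1.0, requires_grad=True)
SketchRedirectedReLU.apply(x_redirected).sum().backward()

print(x_plain.grad)  # All zeros: the optimizer has nothing to follow
print(x_redirected.grad)  # All ones: the gradient still flows
```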

Below we use the `opt.models.replace_layers` function to replace all instances of `nn.ReLU` in our toy model with `opt.models.RedirectedReluLayer`.

In [None]:
relu_model = ToyModel().eval().to(device)

# Replace the ReLU with RedirectedReluLayer
opt.models.replace_layers(
    relu_model, layer1=torch.nn.ReLU, layer2=opt.models.RedirectedReluLayer
)

# Show modified model
print(relu_model)

ToyModel(
  (basic_module): Sequential(
    (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(2, 2))
    (1): RedirectedReluLayer()
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv): Conv2d(4, 4, kernel_size=(3, 3), stride=(2, 2))
  (bn): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): RedirectedReluLayer()
  (pooling): AdaptiveAvgPool2d(output_size=(2, 2))
  (linear): Linear(in_features=16, out_features=4, bias=True)
)


## Circuits: Linear Operation Layers

Certain functions, like `opt.circuits.extract_expanded_weights`, require modules that only perform linear operations. Replacing layers becomes slightly more complicated when a layer has multiple preset variables. Luckily, the `opt.models.replace_layers` function can handle transferring these variables for layer types like pooling layers if the `transfer_vars` variable is set to `True`.


Common linear layer replacements are as follows:

* `nn.ReLU` layers need to be skipped, which can be done by replacing them with either `nn.Identity` or Captum's `SkipLayer` layer.

* `nn.MaxPool2d` layers need to be converted to their linear `nn.AvgPool2d` layer equivalents.

* `nn.AdaptiveMaxPool2d` layers need to be converted to their linear `nn.AdaptiveAvgPool2d` layer equivalents.

Some of the layers which are already linear operations are:

* `nn.BatchNorm2d` is linear when it's in evaluation mode (`.eval()`).
* `nn.Conv2d` is linear.
* `nn.Linear` is linear.
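
The reason for these replacements can be seen with a quick additivity check in plain PyTorch. A linear operation satisfies `layer(a + b) == layer(a) + layer(b)`. The `is_additive` helper below is a hypothetical name used only for this sketch. Passing the check for a single pair of inputs does not prove linearity, but failing it does disprove it.

```python
import torch


def is_additive(layer: torch.nn.Module, a: torch.Tensor, b: torch.Tensor) -> bool:
    """Check layer(a + b) == layer(a) + layer(b) for one pair of inputs."""
    with torch.no_grad():
        return torch.allclose(layer(a + b), layer(a) + layer(b))


# Two 1x1x2x2 inputs chosen so that the non-linear layers visibly fail the check
a = torch.tensor([[[[1.0, -1.0], [0.0, 0.0]]]])
b = torch.tensor([[[[-1.0, 1.0], [0.0, 0.0]]]])

results = {
    "ReLU": is_additive(torch.nn.ReLU(), a, b),
    "Identity": is_additive(torch.nn.Identity(), a, b),
    "MaxPool2d": is_additive(torch.nn.MaxPool2d(kernel_size=2), a, b),
    "AvgPool2d": is_additive(torch.nn.AvgPool2d(kernel_size=2), a, b),
}
print(results)
# {'ReLU': False, 'Identity': True, 'MaxPool2d': False, 'AvgPool2d': True}
```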

In [None]:
linear_only_model = ToyModel().eval().to(device)

# Replace MaxPool2d with AvgPool2d using the same settings
opt.models.replace_layers(
    linear_only_model,
    layer1=torch.nn.MaxPool2d,
    layer2=torch.nn.AvgPool2d,
    transfer_vars=True,  # Use same MaxPool2d parameters for AvgPool2d
)

# Replace ReLU with Identity
opt.models.replace_layers(
    linear_only_model, layer1=torch.nn.ReLU, layer2=torch.nn.Identity
)

# Show modified model
print(linear_only_model)

ToyModel(
  (basic_module): Sequential(
    (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(2, 2))
    (1): Identity()
    (2): AvgPool2d(kernel_size=3, stride=2, padding=0)
  )
  (conv): Conv2d(4, 4, kernel_size=(3, 3), stride=(2, 2))
  (bn): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): Identity()
  (pooling): AdaptiveAvgPool2d(output_size=(2, 2))
  (linear): Linear(in_features=16, out_features=4, bias=True)
)
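
For intuition, a replacement that carries over layer settings can be approximated in plain PyTorch as follows. This is a rough sketch rather than a description of how `replace_layers` works internally, and `swap_max_pool_for_avg_pool` is a hypothetical helper that only handles this one pair of layer types.

```python
import torch


def swap_max_pool_for_avg_pool(model: torch.nn.Module) -> None:
    """Hypothetical helper: recursively replace MaxPool2d with AvgPool2d in place."""
    for name, child in list(model.named_children()):
        if isinstance(child, torch.nn.MaxPool2d):
            # Copy the settings that both pooling layer types share
            replacement = torch.nn.AvgPool2d(
                kernel_size=child.kernel_size,
                stride=child.stride,
                padding=child.padding,
                ceil_mode=child.ceil_mode,
            )
            setattr(model, name, replacement)
        else:
            swap_max_pool_for_avg_pool(child)


model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 4, kernel_size=3, stride=2),
    torch.nn.Sequential(torch.nn.MaxPool2d(kernel_size=3, stride=2)),
)
swap_max_pool_for_avg_pool(model)
pool = model[1][0]
print(type(pool).__name__, pool.kernel_size, pool.stride)
```

Settings that exist on only one of the two layer types, such as the `dilation` of `nn.MaxPool2d`, have no counterpart to copy to, so this sketch ignores them.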


## Other: Relaxed Pooling

Some attribution-based operations, like those used in activation atlas sample collection, require replacing the `nn.MaxPool2d` layers with a special relaxed version called `MaxPool2dRelaxed`. This is also easy to do with the `replace_layers` function, as we did above.

In [None]:
relaxed_pooling_model = ToyModel().eval().to(device).basic_module

# Replace MaxPool2d with MaxPool2dRelaxed
opt.models.replace_layers(
    relaxed_pooling_model,
    torch.nn.MaxPool2d,
    opt.models.MaxPool2dRelaxed,
    transfer_vars=True,  # Use same MaxPool2d parameters for MaxPool2dRelaxed
)

# Show modified model
print(relaxed_pooling_model)

Sequential(
  (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(2, 2))
  (1): ReLU()
  (2): MaxPool2dRelaxed(
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (avgpool): AvgPool2d(kernel_size=3, stride=2, padding=0)
  )
)
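
The printed module wraps both a max pooling layer and an average pooling layer. One common way to combine the two is to return the max pooled values in the forward pass while letting gradients flow through the average pooling path. The sketch below shows that technique under the assumption that it matches the intent of a relaxed pooling layer. `RelaxedMaxPoolSketch` is a hypothetical class and not Captum's implementation.

```python
import torch
import torch.nn.functional as F


class RelaxedMaxPoolSketch(torch.nn.Module):
    """Illustrative only: max pool values in the forward pass, average pool gradients."""

    def __init__(self, kernel_size: int, stride: int) -> None:
        super().__init__()
        self.maxpool = torch.nn.MaxPool2d(kernel_size, stride)
        self.avgpool = torch.nn.AvgPool2d(kernel_size, stride)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        avg = self.avgpool(x)
        # The detached difference changes the value but contributes no gradient
        return avg + (self.maxpool(x) - avg).detach()


x = torch.arange(25.0).reshape(1, 1, 5, 5).requires_grad_()
layer = RelaxedMaxPoolSketch(kernel_size=3, stride=2)
out = layer(x)
out.sum().backward()
relaxed_grad = x.grad.clone()

# Reference gradient from a plain average pooling layer
x_ref = x.detach().clone().requires_grad_()
F.avg_pool2d(x_ref, kernel_size=3, stride=2).sum().backward()

values_match_max_pool = torch.allclose(out, F.max_pool2d(x.detach(), 3, 2))
grads_match_avg_pool = torch.allclose(relaxed_grad, x_ref.grad)
print(values_match_max_pool, grads_match_avg_pool)  # True True
```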
