# Training Linear Probes For Faceted Feature Visualization

This tutorial demonstrates how to train linear probes for use in faceted feature visualization, as described in the [Faceted Feature Visualization section](https://distill.pub/2021/multimodal-neurons/#faceted-feature-visualization) of the research paper Multimodal Neurons in Artificial Neural Networks.

In [None]:
%load_ext autoreload
%autoreload 2

import copy
import time
from collections import Counter
from typing import Dict, List, Optional, Tuple, Union

import captum.optim as opt
import torch
import torchvision

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

## Setup

Before we can start training the linear probes, we need to do a bit of setup. Below we define a helper function for balancing the classes of image datasets, an optional transform that pads input images to squares for datasets requiring more spatial similarity, and a helper function for looking up the class indices of `ImageFolder` datasets.

In [None]:
def balance_training_classes(
    dataloader: torch.utils.data.DataLoader, num_classes: int = 2
) -> List[float]:
    """
    Calculate balancing weights for a given dataloader instance.

    Args:

        dataloader (torch.utils.data.DataLoader): A dataloader instance to count the
            number of images in each class for.
        num_classes (int, optional): The number of classes used in the dataset.
            Default: 2

    Returns:
        weights (list of float): A list of values for balancing the classes.
    """
    train_class_counts = dict(
        Counter(sample_tup[1] for sample_tup in dataloader.dataset)
    )
    train_class_counts = dict(sorted(train_class_counts.items()))
    train_weights = [
        1.0 / train_class_counts[class_id] for class_id in range(num_classes)
    ]
    return train_weights


class PadToSquare(torch.nn.Module):
    """
    Transform for padding rectangular shaped inputs to squares without messing up the
    aspect ratio.
    """

    __constants__ = ["padding_value"]

    def __init__(self, padding_value: float = 0.0) -> None:
        """
        Args:

            padding_value (float, optional): The value to use for the constant
                padding.
                Default: 0.0
        """
        super().__init__()
        self.padding_value = padding_value

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert x.dim() == 4 or x.dim() == 3
        if x.dim() == 4:
            C, H, W = x.shape[1:]
        elif x.dim() == 3:
            C, H, W = x.shape
        top, left = [(max(H, W) - d) // 2 for d in [H, W]]
        bottom, right = [max(H, W) - (d + pad) for d, pad in zip([H, W], [top, left])]

        padding = [left, right, top, bottom]
        if x.dim() == 3:
            return torch.nn.functional.pad(
                x[None, :], padding, value=self.padding_value, mode="constant"
            )[0]
        else:
            return torch.nn.functional.pad(
                x, padding, value=self.padding_value, mode="constant"
            )


def get_dataset_indices(dataset_path: str) -> Dict[str, int]:
    """
    If you are not sure what the class indices are for your training images & the
    generic natural images, then you can use this handy helper function that
    replicates the ordering used by `torchvision.datasets.ImageFolder`.

    Args:

        dataset_path (str): The path to your image dataset that is using the standard
            ImageFolder structure.


    Returns:
        class_and_idx (dict of str and int): The folder names and corresponding class
            indices.
    """
    import os

    classes = [d.name for d in os.scandir(dataset_path) if d.is_dir()]
    classes.sort()
    return {cls_name: i for i, cls_name in enumerate(classes)}
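
As a rough illustration of what `balance_training_classes` computes, the sketch below derives inverse-frequency weights from a made-up list of labels. The label list and the helper name `inverse_frequency_weights` are assumptions for this example only; in the tutorial the labels come from iterating over the dataset.

```python
from collections import Counter
from typing import List, Sequence


def inverse_frequency_weights(labels: Sequence[int], num_classes: int) -> List[float]:
    """Return 1 / count for each class index in range(num_classes)."""
    counts = Counter(labels)
    return [1.0 / counts[class_id] for class_id in range(num_classes)]


# Hypothetical dataset: 8 generic images (class 0) and 2 theme images (class 1)
example_labels = [0] * 8 + [1] * 2
example_weights = inverse_frequency_weights(example_labels, num_classes=2)
print(example_weights)  # [0.125, 0.5]
```

The rarer class receives the larger weight, so both classes contribute about equally to a weighted cross-entropy loss.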

### Dataset Setup


For the purpose of this tutorial, we demonstrate setting up a basic dataset using Torchvision's [ImageFolder](https://pytorch.org/vision/stable/_modules/torchvision/datasets/folder.html#ImageFolder). However, you can use any dataset that works with [`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader). If your dataset does not, you may have to modify the training function to support it.

The authors of the research paper recommend that image datasets contain a minimum of 2 classes, where one class is composed of generic natural images and the other class or classes contain the desired themes / concepts. The basic idea behind this class structure is to train the linear probes to separate a theme / concept from unrelated content.
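
To make the class structure concrete, here is a minimal sketch of how an `ImageFolder`-style directory maps folder names to class indices. The folder names `generic_images` and `my_theme` are hypothetical, and the temporary directory stands in for a real dataset path.

```python
import os
import tempfile

# Hypothetical folder names; class indices follow the sorted folder names
with tempfile.TemporaryDirectory() as root:
    for folder in ["my_theme", "generic_images"]:
        os.makedirs(os.path.join(root, folder))
    class_names = sorted(d.name for d in os.scandir(root) if d.is_dir())
    class_to_idx = {name: i for i, name in enumerate(class_names)}

print(class_to_idx)  # {'generic_images': 0, 'my_theme': 1}
```

Because the ordering is alphabetical, the index of the theme class depends entirely on how the folders are named, so it is worth checking before slicing out trained weights.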

**Spatial information in your dataset**

In the research paper, the authors trained some of the facets on images where the features in each image in the dataset were in roughly the same locations. This is important to note only if you are trying to create similar facets where you want more spatially coherent shapes like those of the `face` facet used in other tutorials.
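
For datasets where spatial layout matters, padding images to squares before resizing avoids distorting the aspect ratio. The arithmetic used by the `PadToSquare` transform above can be sketched on its own as follows; the helper name `square_padding` and the image size are assumptions for illustration.

```python
from typing import List


def square_padding(height: int, width: int) -> List[int]:
    """Return [left, right, top, bottom] padding that makes the input square."""
    side = max(height, width)
    top, left = (side - height) // 2, (side - width) // 2
    # Any odd leftover pixel goes to the bottom or right edge
    bottom, right = side - (height + top), side - (width + left)
    return [left, right, top, bottom]


# A hypothetical 200 x 301 (H x W) image is padded to 301 x 301
example_padding = square_padding(200, 301)
print(example_padding)  # [0, 0, 50, 51]
```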

In [None]:
def create_dataloaders(
    dataset_path: str,
    batch_size: int = 32,
    val_percent: float = 0.0,
    training_transforms: Optional[torch.nn.Module] = None,
    validation_transforms: Optional[torch.nn.Module] = None,
    balance_classes: bool = False,
    num_classes: int = 2,
) -> Dict[str, Union[torch.utils.data.DataLoader, List[float]]]:
    """
    Create one or more dataloader instances with optional balancing weights for a
    given image dataset, with Torchvision's ImageFolder directory format.

    https://pytorch.org/vision/stable/_modules/torchvision/datasets/folder.html#ImageFolder

    Args:

        dataset_path (str): The path to the image dataset to use for torchvision's
            ImageFolder dataset. See above for more details.
        batch_size (int, optional): The batch size to use.
            Default: 32
        val_percent (float, optional): The fraction of the dataset to use for
            validation, in the range [0.0, 1.0). If set to 0 then no validation
            dataset will be created.
            Default: 0.0
        training_transforms (nn.Module, optional): Transforms to use for training the
            linear probes.
            Default: None
        validation_transforms (nn.Module, optional): Transforms to use for validation.
            Required if val_percent is greater than 0.
            Default: None
        balance_classes (bool, optional): Whether or not to calculate weights for
            balancing the training classes.
            Default: False
        num_classes (int, optional): If balance_classes is set to True, then this
            variable provides the number of classes in the dataset to use in the
            balancing calculations.
            Default: 2

    Returns:
        dataloaders (dict of dataloader and list of float): A dictionary containing
            the training dataloader, with optional validation dataloader and balancing
            weights for the training dataloader.
    """
    full_dataset = torchvision.datasets.ImageFolder(
        root=dataset_path,
    )

    if val_percent > 0.0:
        assert validation_transforms is not None
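        n = len(full_dataset)
        # Derive the training length from the validation length so that the two
        # always sum to n, as random_split requires
        n_val = round(n * val_percent)
        lengths = [n - n_val, n_val]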

        t_data, v_data = torch.utils.data.random_split(full_dataset, lengths)
        t_data = copy.deepcopy(t_data)

        t_data.dataset.transform = training_transforms
        v_data.dataset.transform = validation_transforms

        t_dataloader = torch.utils.data.DataLoader(
            t_data,
            batch_size=batch_size,
            shuffle=True,
        )
        v_dataloader = torch.utils.data.DataLoader(
            v_data, batch_size=batch_size, shuffle=True
        )
        dataloader = {"train": t_dataloader, "val": v_dataloader}
    else:
        t_dataset = torch.utils.data.Subset(
            copy.deepcopy(full_dataset), range(0, len(full_dataset))
        )
        t_dataset.dataset.transform = training_transforms
        t_dataloader = torch.utils.data.DataLoader(
            t_dataset, batch_size=batch_size, shuffle=True
        )
        dataloader = {"train": t_dataloader}

    if balance_classes:
        train_weights = balance_training_classes(dataloader["train"], num_classes)
        dataloader["train_weights"] = train_weights
    return dataloader
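
When given integer lengths, `torch.utils.data.random_split` requires them to sum exactly to the dataset size. One way to guarantee that, sketched here with a hypothetical helper named `split_lengths`, is to round only the validation length and derive the training length from it. Note that `val_percent` is treated as a fraction, so `0.25` means 25% of the dataset.

```python
from typing import List


def split_lengths(n: int, val_percent: float) -> List[int]:
    """Return [train_length, val_length] that always sum to n."""
    n_val = round(n * val_percent)
    return [n - n_val, n_val]


example_lengths = split_lengths(1001, 0.25)
print(example_lengths)  # [751, 250]
```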

### Training Function

The model training function's `dataloaders` variable requires the dataloaders to be organized into a dictionary containing the following keys and values:

* `train`: The training dataloader.
* `val`: Optionally include validation dataloader. If this key doesn't exist in the dict, then no validation phase will be performed.
* `train_weights`: Optionally include a list of training weights to balance the classes during training.


Linear probes are implemented as [`nn.LazyLinear`](https://pytorch.org/docs/stable/generated/torch.nn.LazyLinear.html) layers with a reshaping operation between them and the target layer.
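
The following is a minimal sketch of that idea in isolation, using random tensors in place of real layer activations. The shapes are arbitrary assumptions chosen to keep the example small.

```python
import torch

# Hypothetical activations from a conv layer: batch of 4, 8 channels, 3 x 3 spatial
activations = torch.randn(4, 8, 3, 3)

# in_features is inferred from the input on the first forward pass
probe = torch.nn.LazyLinear(2, bias=False)
logits = probe(activations.reshape(activations.shape[0], -1))

print(logits.shape)  # torch.Size([4, 2])
print(probe.weight.shape)  # torch.Size([2, 72])
```

Each row of the probe's weight matrix corresponds to one class, and each column to one element of the flattened `8 * 3 * 3` activation.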

In [None]:
def train_linear_probes(
    model: torch.nn.Module,
    target_layers: List[torch.nn.Module],
    dataloaders: Dict[str, Union[torch.utils.data.DataLoader, List[float]]],
    out_features: int = 2,
    num_epochs: int = 10,
    lr: float = 1.0,
    l1_weight: float = 0.0,
    l2_weight: float = 0.0,
    use_optimizer: str = "lbfgs",
    device: torch.device = torch.device("cpu"),
    save_epoch: Optional[int] = None,
    save_path: str = "epoch_",
    verbose: bool = True,
    show_progress: bool = False,
) -> Tuple[List[torch.Tensor], List[float]]:
    """
    Train linear probes on target layers of a specified model, for use as faceted
    feature visualization facet weights.

    Args:

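        model (nn.Module): A PyTorch model instance.
        target_layers (list of nn.Module): A list of model targets to train linear
            probes for.
        dataloaders (dict of torch.utils.data.DataLoader): A dictionary of PyTorch
            Dataloader instances for training and optionally for validation.
        out_features (int, optional): The number of output features for each linear
            probe. This should match the number of classes in the dataset.
            Default: 2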
        num_epochs (int, optional): The number of epochs to train for.
            Default: 10
        l1_weight (float, optional): The desired l1 penalty weight to use.
            Default: 0.0
        l2_weight (float, optional): The desired l2 penalty weight to use.
            Default: 0.0
        lr (float, optional): The desired learning rate to use with the optimizer.
            Default: 1.0
        use_optimizer (str, optional): The optimizer to use. Choices are: "sgd" or
            "lbfgs".
            Default: "lbfgs"
        device (torch.device, optional): The device to place training inputs on before
            sending them through the model.
            Default: torch.device("cpu")
        save_epoch (int, optional): Save the best model weights every save_epoch
            epochs. Set to None to not save any epochs.
            Default: None
        save_path (str, optional): If save_epoch is not None, save model weights with
            the path / name: <save_path + epoch + ".pt">.
            Default: "epoch_"
        verbose (bool, optional): Whether or not to print loss and accuracy after
            every epoch.
            Default: True
        show_progress (bool, optional): Currently unused by this function.
            Default: False

    Returns:
        weights (list of torch.Tensor): The weights of the best scoring models from
            the training session. The order of the weights corresponds to
            `target_layers`.
        best_acc (list of float): The accuracies for the returned weights, measured
            on the validation set if one is provided and on the training set
            otherwise. The order corresponds to `weights`.
    """
    assert use_optimizer in ["lbfgs", "sgd"]
    assert "train" in dataloaders

    phases = ["train", "val"] if "val" in dataloaders else ["train"]

    # Optionally balance classes if provided with weight balancing tensor
    if "train_weights" in dataloaders:
        crit_weights = torch.FloatTensor(dataloaders["train_weights"])
        criterion = torch.nn.CrossEntropyLoss(weight=crit_weights).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    # Create Linear Probes using LazyLinear so that we don't need to specify an input size
    layer_probes = [
        torch.nn.LazyLinear(out_features, bias=False).to(device).train()
        for _ in target_layers
    ]
    num_probes = len(target_layers)

    # Setup model saving
    best_models = [None for _ in layer_probes]
    best_accs = [0.0] * num_probes

    # Setup optimizer
    parameters = []
    for p in layer_probes:
        parameters += list(p.parameters())
    if use_optimizer == "lbfgs":
        optimizer = torch.optim.LBFGS(
            parameters, lr=lr, max_iter=1, tolerance_change=-1, tolerance_grad=-1
        )
    else:
        optimizer = torch.optim.SGD(parameters, lr=lr, momentum=0.0, weight_decay=0.0)

    # Get dataset lengths beforehand to speed things up
    val_length = 0 if "val" not in dataloaders else len(dataloaders["val"].dataset)
    dataset_length = {"train": len(dataloaders["train"].dataset), "val": val_length}

    start_time = time.time()
    for epoch in range(num_epochs):
        if verbose:
            print("Epoch {}/{}".format(epoch + 1, num_epochs))
            print("-" * 12)

        for phase in phases:
            # Put the probes into the mode that matches the current phase
            for probe in layer_probes:
                probe.train(phase == "train")

            phase_stats = {
                "epoch_acc": [0.0] * num_probes,
                "epoch_loss": [0.0] * num_probes,
            }

            for inputs, labels in dataloaders[phase]:
                inputs, labels = inputs.to(device), labels.to(device)

                with torch.set_grad_enabled(phase == "train"):
                    if use_optimizer == "lbfgs":
                        # Training with torch.optim.LBFGS

                        def closure() -> torch.Tensor:
                            optimizer.zero_grad()
                            # Collect outputs for target layers
                            probe_inputs = opt.models.collect_activations(
                                model, target_layers, inputs
                            )
                            outputs = [probe_inputs[target] for target in target_layers]

                            # Send layer outputs through linear probes
                            outputs = [
                                probe(x.reshape(x.shape[0], -1))
                                for x, probe in zip(outputs, layer_probes)
                            ]

                            probe_losses = [
                                criterion(outputs[i], labels) for i in range(num_probes)
                            ]
                            preds = [
                                torch.max(outputs[i], 1)[1] for i in range(num_probes)
                            ]
                            loss = sum(probe_losses)

                            if phase == "train":

                                # Apply optional L1 or L2 penalties
                                if l1_weight != 0.0 or l2_weight != 0.0:
                                    if l1_weight != 0.0:
                                        l1_penalty = sum(
                                            [
                                                l1_weight * p.weight.abs().sum()
                                                for p in layer_probes
                                            ]
                                        )
                                        loss = loss + l1_penalty
                                    if l2_weight != 0.0:
                                        l2_penalty = l2_weight * sum(
                                            [
                                                (p.weight**2).sum()
                                                for p in layer_probes
                                            ]
                                        )
                                        loss = loss + l2_penalty

                                loss.backward()

                            with torch.no_grad():
                                phase_stats["epoch_loss"] = [
                                    phase_stats["epoch_loss"][i]
                                    + l.detach().item() * inputs.size(0)
                                    for i, l in enumerate(probe_losses)
                                ]
                                phase_stats["epoch_acc"] = [
                                    phase_stats["epoch_acc"][i]
                                    + torch.sum(p == labels).detach().item()
                                    for i, p in enumerate(preds)
                                ]
                            return loss

                        if phase == "train":
                            optimizer.step(closure)
                        else:
                            # Only evaluate during validation; skip the optimizer update
                            closure()
                    else:
                        # Training with torch.optim.SGD

                        optimizer.zero_grad()
                        # Collect outputs for target layers
                        probe_inputs = opt.models.collect_activations(
                            model, target_layers, inputs
                        )
                        outputs = [probe_inputs[target] for target in target_layers]

                        # Send layer outputs through linear probes
                        outputs = [
                            probe(x.reshape(x.shape[0], -1))
                            for x, probe in zip(outputs, layer_probes)
                        ]

                        probe_losses = [
                            criterion(outputs[i], labels)
                            for i in range(len(layer_probes))
                        ]
                        preds = [
                            torch.max(outputs[i], 1)[1]
                            for i in range(len(layer_probes))
                        ]

                        loss = sum(probe_losses)

                        if phase == "train":

                            # Apply optional L1 or L2 penalties
                            if l1_weight != 0.0:
                                l1_penalty = sum(
                                    [
                                        l1_weight * p.weight.abs().sum()
                                        for p in layer_probes
                                    ]
                                )
                                loss = loss + l1_penalty
                            if l2_weight != 0.0:
                                l2_penalty = l2_weight * sum(
                                    [(p.weight**2).sum() for p in layer_probes]
                                )
                                loss = loss + l2_penalty

                            loss.backward()
                            optimizer.step()

                        with torch.no_grad():
                            phase_stats["epoch_loss"] = [
                                phase_stats["epoch_loss"][i]
                                + l.detach().item() * inputs.size(0)
                                for i, l in enumerate(probe_losses)
                            ]
                            phase_stats["epoch_acc"] = [
                                phase_stats["epoch_acc"][i]
                                + torch.sum(p == labels).detach().item()
                                for i, p in enumerate(preds)
                            ]

            phase_stats["epoch_loss"] = [
                phase_stats["epoch_loss"][i] / dataset_length[phase]
                for i in range(num_probes)
            ]
            phase_stats["epoch_acc"] = [
                phase_stats["epoch_acc"][i] / dataset_length[phase]
                for i in range(num_probes)
            ]

            # Make sure we keep the best model weights
            if phase == "val" or "val" not in phases:
                for i, acc in enumerate(phase_stats["epoch_acc"]):
                    if acc > best_accs[i]:
                        best_accs[i] = acc
                        best_models[i] = layer_probes[i].weight.clone().detach().cpu()

            if verbose:
                print(
                    "{} Loss: {:.4f} Acc: {:.4f}".format(
                        phase,
                        sum(phase_stats["epoch_loss"]) / num_probes,
                        sum(phase_stats["epoch_acc"]) / num_probes,
                    )
                )
                print("  Loss: ", [round(v, 4) for v in phase_stats["epoch_loss"]])
                print("  Acc: ", [round(acc, 4) for acc in phase_stats["epoch_acc"]])
                time_elapsed = time.time() - start_time
                print(
                    "Time Elapsed {:.0f}m {:.0f}s".format(
                        time_elapsed // 60, time_elapsed % 60
                    )
                )
                if epoch + 1 != num_epochs:
                    print()

        if save_epoch and (epoch + 1) % save_epoch == 0 and (epoch + 1) != num_epochs:
            facet_weights = [w.clone().cpu().detach() for w in best_models]
            filename = save_path + str(epoch + 1) + ".pt"
            torch.save([w.cpu() for w in facet_weights], filename)

    return best_models, best_accs
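
The training function accumulates each batch's mean loss multiplied by the batch size, then divides by the dataset length at the end of the epoch. The sketch below uses made-up batch losses and sizes to show why this matters when the final batch is smaller than the others.

```python
# Hypothetical (mean batch loss, batch size) pairs; the last batch is smaller
batches = [(0.50, 16), (0.30, 16), (0.90, 4)]
dataset_length = sum(size for _, size in batches)

# Weight each batch mean by its size, as the training loop does
epoch_loss = sum(loss * size for loss, size in batches) / dataset_length
# A plain average of the batch means over-weights the small final batch
naive_loss = sum(loss for loss, _ in batches) / len(batches)

print(round(epoch_loss, 4), round(naive_loss, 4))  # 0.4556 0.5667
```

When class weights are passed to `CrossEntropyLoss`, each batch mean is itself a weighted mean, so the reported epoch loss should be read as an approximation in that case.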

### Load Model & Dataset

Now that we have the required classes and functions defined, we load the CLIP ResNet 50x4 image model without `RedirectedReLU`.

In [None]:
# Load image model
clip_model = (
    opt.models.clip_resnet50x4_image(
        pretrained=True, replace_relus_with_redirectedrelu=False
    )
    .eval()
    .to(device)
)

Next we load our dataset's dataloaders for training. Remember that our dataloader creation function uses Torchvision's ImageFolder, and thus different datasets may need their own setup functions.

In [None]:
dataset_path = "my_dataset"  # Path to dataset
num_classes = 2  # Number of classes in our dataset

# Setup transforms for training
training_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        # PadToSquare(1.0),
        torchvision.transforms.Resize((288, 288), antialias=True),
    ]
)

dataloaders = create_dataloaders(
    dataset_path,
    batch_size=16,
    val_percent=0.0,
    training_transforms=training_transforms,
    balance_classes=True,
    num_classes=num_classes,
)

## Training The Linear Probes

We can now begin training the linear probes on the target layers! Below we train linear probes on the same 5 lower layers as the researchers did in the paper.

Note that using the [L-BFGS optimizer](https://pytorch.org/docs/stable/generated/torch.optim.LBFGS.html) will generally produce the best quality facets, but it will also use more memory than the [SGD optimizer](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html). Memory usage can also be reduced by training fewer linear probes at once.

Note that you may have to adjust the default parameters for training for custom datasets and models.
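
Unlike SGD, `torch.optim.LBFGS` takes a closure that recomputes the loss and gradients. The sketch below shows that pattern on a tiny stand-in problem, with the same optimizer settings as the training function above. The random features and labels are assumptions that replace real layer activations and dataset labels.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins for flattened activations and binary labels
features = torch.randn(32, 10)
labels = (features[:, 0] > 0).long()

probe = torch.nn.Linear(10, 2, bias=False)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.LBFGS(
    probe.parameters(), lr=1.0, max_iter=1, tolerance_change=-1, tolerance_grad=-1
)


def closure() -> torch.Tensor:
    # L-BFGS calls this itself, so the forward and backward pass live here
    optimizer.zero_grad()
    loss = criterion(probe(features), labels)
    loss.backward()
    return loss


for _ in range(5):
    # step() returns the loss from its first closure evaluation
    final_loss = optimizer.step(closure)

print(float(final_loss))
```

With `torch.optim.SGD`, the same forward and backward pass would run directly in the loop, followed by `optimizer.step()` without a closure.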

In [None]:
# Layers to train linear probes for
target_layers = [
    clip_model.layer3[0].relu3,
    clip_model.layer3[2].relu3,
    clip_model.layer3[4].relu3,
    clip_model.layer3[6].relu3,
    clip_model.layer3[8].relu3,
]


# The L-BFGS optimizer will use more memory than the SGD optimizer
use_optimizer = "lbfgs"  # Whether to optimize with "lbfgs" or "sgd"

# Optimizer specific param setup
if use_optimizer == "lbfgs":
    l2_weight = 0.0
    lr = 1.0
else:
    l2_weight = 0.316
    lr = 0.0001

# Train linear probes
weights, weight_accs = train_linear_probes(
    model=clip_model,
    target_layers=target_layers,
    dataloaders=dataloaders,
    # This should be the same as the number of classes in the dataset
    out_features=num_classes,
    num_epochs=5,
    lr=lr,
    l2_weight=l2_weight,
    use_optimizer=use_optimizer,
    device=device,
)

Epoch 1/5
------------
train Loss: 390337.9189 Acc: 0.9715
  Loss:  [56043.4749, 1363915.4473, 124310.3623, 168846.0195, 238574.2905]
  Acc:  [0.9718, 0.966, 0.9722, 0.9705, 0.9771]
Time Elapsed 3m 14s

Epoch 2/5
------------
train Loss: 16781.2769 Acc: 0.9976
  Loss:  [14076.3319, 31218.2309, 6106.3447, 19327.1426, 13178.3344]
  Acc:  [0.9958, 0.9979, 0.9986, 0.9969, 0.999]
Time Elapsed 6m 31s

Epoch 3/5
------------
train Loss: 329.2152 Acc: 0.9994
  Loss:  [689.9083, 327.7661, 481.1846, 147.2171, 0.0]
  Acc:  [0.9982, 0.9997, 0.9994, 0.9994, 1.0]
Time Elapsed 9m 48s

Epoch 4/5
------------
train Loss: 468.3097 Acc: 0.9989
  Loss:  [546.3372, 485.5594, 319.5212, 988.2269, 1.9037]
  Acc:  [0.9987, 0.999, 0.9993, 0.9978, 0.9999]
Time Elapsed 13m 5s

Epoch 5/5
------------
train Loss: 100.6919 Acc: 0.9997
  Loss:  [236.6766, 138.6808, 78.6038, 49.4981, 0.0]
  Acc:  [0.9994, 0.9997, 0.9997, 0.9997, 1.0]
Time Elapsed 16m 21s


Now that we have our trained weights, we can slice out the rows of each weight matrix that correspond to the theme / concept we trained on, while ignoring the row for the generic natural images. Each probe's weight has the shape `[out_features, in_features]`, so the class index selects along the first dimension. For this tutorial we only trained 1 class in addition to the generic natural images, so we only have one index of weights to collect.

In [None]:
# Uncomment to get dataset class indices for ImageFolder datasets
# print(get_dataset_indices(dataset_path))

In [None]:
# We only need the theme / concept part of the weights
theme_idx = 0  # Class idx for the target theme / concept
facet_weights = [w[theme_idx : theme_idx + 1] for w in weights]

The `nn.LazyLinear` layers used to train the probes operate on 2D inputs, so the outputs of 4D layer targets like `nn.Conv2d` layers were flattened during training. The resulting probe weights therefore need to be reshaped to match the 4D output shapes of their target layers. For this tutorial, all layer targets have an output shape of: `[N, 1280, 18, 18]`.

In [None]:
# Uncomment to view the shape of each layer
# out_dict = opt.models.collect_activations(
#     clip_model, target_layers, torch.zeros(1, 3, 288, 288, device=device)
# )
# print([out_dict[t].shape for t in target_layers])

In [None]:
# Each probe weight can be reshaped to match its corresponding model layer
facet_weights = [w.reshape(1, 1280, 18, 18) for w in facet_weights]
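
As a sanity check on the reshape, the sketch below uses small random tensors (hypothetical stand-ins for the `[N, 1280, 18, 18]` activations and a trained probe row) to show that the reshaped weight produces the same logit as the flattened probe.

```python
import torch

torch.manual_seed(0)

# Small hypothetical shapes standing in for [N, 1280, 18, 18]
N, C, H, W = 2, 4, 3, 3
activations = torch.randn(N, C, H, W)
probe_weight = torch.randn(1, C * H * W)  # One class row from a probe

# Logit computed the way the probe does it, on flattened activations
flat_logits = activations.reshape(N, -1) @ probe_weight.T

# The same logit computed with the weight reshaped to the layer's output shape
facet_weight = probe_weight.reshape(1, C, H, W)
spatial_logits = (activations * facet_weight).sum(dim=(1, 2, 3))

shapes_agree = torch.allclose(flat_logits[:, 0], spatial_logits, atol=1e-5)
print(shapes_agree)  # True
```

The two results match because both reshapes use the same row-major ordering, so each weight element lines up with the activation element it was trained against.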

We can now save our facet weights as they are ready for use in faceted feature visualization!

In [None]:
# Save the trained weights
torch.save([w.cpu() for w in facet_weights], "my_facet_weights.pt")

# Then the weights can be loaded like this
# facet_weights = torch.load("my_facet_weights.pt")

If you trained multiple facet themes at once, then you can save them individually like in the example code below.

In [None]:
# Uncomment to save multiple facets
# theme_indices = [0, 1]
# for idx in theme_indices:
#     facet_weights = [w[idx : idx + 1].reshape(1, 1280, 18, 18) for w in weights]
#     torch.save(
#         [w.cpu() for w in facet_weights], "my_facet_weights_{}_.pt".format(idx)
#     )

The facet weights can then be loaded and used for the `FacetLoss` objective's required `facet_weights` variable.