# Demo: Non-identifiability for a larger network (MNIST dataset)

In this notebook, we use the `mi_identifiability` library to demonstrate the non-identifiability of mechanistic interpretability (MI) criteria for circuit detection at a larger scale.

To do so, we train a larger MLP on a subset of MNIST, filtered to contain only images labeled 0 or 1. This simplified problem lets us focus on extracting circuits from the last layers of the network, which are small enough for exhaustive enumeration to remain tractable.
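
To give a rough sense of why only the last layers are enumerated, the sketch below counts the weights (edges) of the whole MLP and of its last three layers, using the layer sizes set later in this notebook. Treating a candidate circuit as an arbitrary subset of edges is an assumption made here for illustration; the search space actually explored by `find_circuits` may be defined differently. `count_edges` is a hypothetical helper, not part of the library.

```python
import math


def count_edges(layer_sizes):
    """Number of weights (edges) in a fully connected network, ignoring biases."""
    return sum(a * b for a, b in zip(layer_sizes, layer_sizes[1:]))


full_sizes = [28 * 28, 128, 128, 3, 3, 3, 1]  # whole MLP: input, hidden layers, output
tail_sizes = full_sizes[-4:]                   # last three layers: 3 -> 3 -> 3 -> 1

full_edges = count_edges(full_sizes)
tail_edges = count_edges(tail_sizes)

# Assumption: one candidate per subset of edges, i.e. 2 ** (number of edges).
full_exponent = int(full_edges * math.log10(2))  # order of magnitude of 2 ** full_edges
print(f"Full model: {full_edges} edges, roughly 10^{full_exponent} edge subsets")
print(f"Last three layers: {tail_edges} edges, {2 ** tail_edges} edge subsets")
```

Under this assumption, the last three layers give a search space that can be enumerated, whereas the full model does not.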


### Install/import packages and set parameters

In [1]:
# Uncomment this line to install the library if running on Colab
# !pip install git+https://github.com/MelouxM/MI-identifiability.git

In [2]:
import numpy as np
import torch
from torch import optim, nn
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from tqdm import tqdm

from mi_identifiability.neural_model import MLP
from mi_identifiability.circuit import find_circuits
from mi_identifiability.utils import set_seeds

# Training parameters
batch_size = 64
learning_rate = 1e-3
epochs = 10

# MLP parameters
input_size = 28 * 28  # MNIST images are 28x28
hidden_sizes = [128, 128, 3, 3, 3]  # Two wide layers followed by three narrow 3-unit layers
output_size = 1  # Single scalar output regressed onto the label (0 or 1)
seed = 42

set_seeds(seed)

### Create training and validation datasets

We download the full MNIST dataset and filter it to only include images labeled as 0 or 1, then partition it into a training and a validation set.

In [3]:
def get_mnist_data_01():
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))  # Normalize between -1 and 1
    ])

    # Download the full dataset
    mnist_train = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    mnist_val = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

    # Filter only the digits 0 and 1
    idx_train = (mnist_train.targets == 0) | (mnist_train.targets == 1)
    idx_val = (mnist_val.targets == 0) | (mnist_val.targets == 1)

    # Subset datasets
    train_subset = Subset(mnist_train, torch.where(idx_train)[0])
    val_subset = Subset(mnist_val, torch.where(idx_val)[0])

    # Convert targets to 1D float tensors (indexing the labels directly avoids loading every image)
    train_targets = mnist_train.targets[idx_train].float()
    val_targets = mnist_val.targets[idx_val].float()

    # Create a new dataset class that combines images and targets
    class MNISTSubset(torch.utils.data.Dataset):
        def __init__(self, subset, targets):
            self.subset = subset
            self.targets = targets

        def __getitem__(self, index):
            x, _ = self.subset[index]  # Get the image
            y = self.targets[index]  # Get the corresponding target
            return x, y

        def __len__(self):
            return len(self.subset)

    # Create new subsets
    train_dataset = MNISTSubset(train_subset, train_targets)
    val_dataset = MNISTSubset(val_subset, val_targets)

    return train_dataset, val_dataset
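
The filtering step can be illustrated on toy data. The sketch below uses plain Python lists as stand-ins for the MNIST images and targets (the names `filter_by_label`, `labels`, and `images` are hypothetical) and shows that selecting samples by position keeps each image paired with its label.

```python
def filter_by_label(labels, keep=(0, 1)):
    """Return the positions of samples whose label is in `keep`, in original order."""
    return [i for i, label in enumerate(labels) if label in keep]


labels = [5, 0, 4, 1, 9, 1, 0]                     # toy labels standing in for MNIST targets
images = [f"img{i}" for i in range(len(labels))]   # toy placeholders for the images

kept = filter_by_label(labels)
# Pair each kept image with its label converted to float, as in the regression setup
subset = [(images[i], float(labels[i])) for i in kept]
print(subset)
```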

### Evaluation and training loops

The evaluation loop is a standard one. Since our model is an MLP rather than a CNN, each image needs to be flattened into a 1D vector (resulting in a 2D tensor for the entire batch).

In [4]:
def evaluate_model(model, val_loader, criterion, epoch):
    model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for x_val, y_val in val_loader:
            x_val = x_val.view(x_val.size(0), -1)  # Flatten to 2D for MLP

            outputs = model(x_val)

            loss = criterion(outputs, y_val.view(-1, 1))
            val_loss += loss.item()

    avg_val_loss = val_loss / len(val_loader)

    print(f"Validation Loss after Epoch [{epoch + 1}]: {avg_val_loss:.4f}")

We similarly define a function for training the model:

In [5]:
def train_model(model, train_loader, val_loader, learning_rate, epochs):
    model.train()

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()

    for epoch in range(epochs):
        model.train()  # evaluate_model() switches to eval mode, so re-enable training mode each epoch
        epoch_loss = 0.0
        for x_batch, y_batch in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{epochs}', unit='batch'):
            x_batch = x_batch.view(x_batch.size(0), -1)  # Flatten to 2D for MLP

            optimizer.zero_grad()
            outputs = model(x_batch)
            loss = criterion(outputs, y_batch.view(-1, 1))
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        evaluate_model(model, val_loader, criterion, epoch)

We can now download the data and create and train the model:

In [6]:
train_dataset, val_dataset = get_mnist_data_01()
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

model = MLP(hidden_sizes=hidden_sizes, input_size=input_size, output_size=output_size)
train_model(model, train_loader, val_loader, learning_rate, epochs)

Epoch 1/10: 100%|██████████| 198/198 [00:06<00:00, 32.60batch/s]
Validation Loss after Epoch [1]: 0.2126
Epoch 2/10: 100%|██████████| 198/198 [00:04<00:00, 45.35batch/s]
Validation Loss after Epoch [2]: 0.1026
Epoch 3/10: 100%|██████████| 198/198 [00:03<00:00, 51.46batch/s]
Validation Loss after Epoch [3]: 0.0032
Epoch 4/10: 100%|██████████| 198/198 [00:04<00:00, 48.82batch/s]
Validation Loss after Epoch [4]: 0.0016
Epoch 5/10: 100%|██████████| 198/198 [00:04<00:00, 48.44batch/s]
Validation Loss after Epoch [5]: 0.0034
Epoch 6/10: 100%|██████████| 198/198 [00:02<00:00, 76.48batch/s]
Validation Loss after Epoch [6]: 0.0026
Epoch 7/10: 100%|██████████| 198/198 [00:03<00:00, 60.77batch/s]
Validation Loss after Epoch [7]: 0.0022
Epoch 8/10: 100%|██████████| 198/198 [00:02<00:00, 67.73batch/s]
Validation Loss after Epoch [8]: 0.0016
Epoch 9/10: 100%|██████████| 198/198 [00:03<00:00, 58.10batch/s]
Validation Loss after Epoch [9]: 0.0013
Epoch 10/10: 100%|██████████| 198/198 [00:04<00:00, 39.99batch/s]
Validation Loss after Epoch [10]: 0.0014

### Extracting circuits

We are interested in interpreting the submodel consisting of the last three layers of the model. To do so, we need the inputs to this submodel, that is, the activations output by the preceding layer. The following helper function records these activations over an entire dataset and returns them together with the flattened input images and the labels.

In [7]:
def collect_activations(submodel, data_loader):
    submodel.eval()
    all_activations = []
    all_labels = []

    inputs_all = []
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs = inputs.view(inputs.size(0), -1)  # Flatten images to vectors
            inputs_all.append(inputs)
            layer_activations = submodel(inputs, return_activations=True)[-1]
            all_activations.append(layer_activations)
            all_labels.append(labels)

    # Concatenate the per-batch inputs, activations, and labels over the whole dataset
    all_inputs = torch.cat(inputs_all, dim=0).float()
    # as_tensor also accepts NumPy arrays, in case the activations are not tensors
    all_activations = torch.cat([torch.as_tensor(a) for a in all_activations], dim=0).float()
    all_labels = torch.cat(all_labels, dim=0).unsqueeze(-1)

    return all_inputs, all_activations, all_labels

We then split the model into two parts, with the shallower part used to record activations and the deeper part used for circuit extraction.

In [8]:
first_layers = model[:-3]
last_layers = model[-3:]

x_val, x_val_h, y_val_h = collect_activations(first_layers, val_loader)

model.eval()
last_layers.eval()

MLP(
  (layers): ModuleList(
    (0-1): 2 x Sequential(
      (0): Linear(in_features=3, out_features=3, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
    )
    (2): Sequential(
      (0): Linear(in_features=3, out_features=1, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
    )
  )
)

We then perform a consistency check: running the original model on the validation data and running the deeper submodel on the recorded activations should yield the same outputs (predictions).

In [9]:
with torch.no_grad():
    original_predictions = model(x_val)
    smaller_predictions = last_layers(x_val_h)

    rounded_orig = torch.round(original_predictions)
    rounded_small = torch.round(smaller_predictions)

    correct_predictions_orig = rounded_orig.eq(y_val_h).all(dim=1)
    accuracy_orig = correct_predictions_orig.sum().item() / y_val_h.size(0)
    print(f"Accuracy orig: {accuracy_orig}")

    correct_predictions_small = rounded_small.eq(y_val_h).all(dim=1)
    accuracy_small = correct_predictions_small.sum().item() / y_val_h.size(0)
    print(f"Accuracy small: {accuracy_small}")

predictions_equal = torch.allclose(original_predictions, smaller_predictions, atol=1e-6)

if predictions_equal:
    print("The predictions from the original MLP and the smaller MLP are the same.")
else:
    print("The predictions from the original MLP and the smaller MLP differ.")

Accuracy orig: 0.9985815602836879
Accuracy small: 0.9985815602836879
The predictions from the original MLP and the smaller MLP are the same.
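
The check above relies on the fact that a feed-forward model is a composition of layers, so cutting it at any layer and feeding the recorded activations into the remaining layers reproduces the full forward pass. A minimal sketch with a hypothetical chain of scalar layers (plain Python, not the library's `MLP`; all weights are made up for illustration):

```python
def leaky_relu(z, slope=0.01):
    return z if z > 0 else slope * z


def make_layer(weight, bias):
    """Return a scalar 'layer': an affine map followed by LeakyReLU."""
    return lambda x: leaky_relu(weight * x + bias)


def run(layers, x):
    """Apply the layers in order to the input x."""
    for layer in layers:
        x = layer(x)
    return x


# Hypothetical stand-in for the trained model: five scalar layers
params = [(0.5, 0.1), (-1.2, 0.3), (2.0, -0.4), (0.7, 0.0), (1.5, 0.2)]
layers = [make_layer(w, b) for w, b in params]
first_layers, last_layers = layers[:-3], layers[-3:]

inputs = [-1.0, 0.0, 0.5, 2.0]
hidden = [run(first_layers, x) for x in inputs]     # recorded activations
full_out = [run(layers, x) for x in inputs]         # original model on raw inputs
split_out = [run(last_layers, h) for h in hidden]   # deeper part on recorded activations

max_diff = max(abs(a - b) for a, b in zip(full_out, split_out))
print(f"Largest difference between the two computations: {max_diff}")
```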


Finally, we perform the circuit search on the deeper (and smaller) model.

In [10]:
top_sk, spar, df = find_circuits(last_layers, x_val_h, y_val_h, accuracy_threshold=0.99)
print(f"Number of circuits: {len(top_sk)}")

Number of circuits: 4702
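
To illustrate how an exhaustive search can return several circuits for the same behavior, the sketch below enumerates all subsets of hidden neurons in a hand-built toy network and keeps those that reach an accuracy threshold. It is not the implementation of `find_circuits`: the network, the zero-ablation of removed neurons, the 0.5 decision cut-off, and all names (`forward`, `accuracy`, `enumerate_circuits`) are assumptions made for this example.

```python
from itertools import product


def leaky_relu(z, slope=0.01):
    return z if z > 0 else slope * z


# Hypothetical toy network: 1 input -> 2 redundant hidden neurons -> 1 output
W_HIDDEN = [1.0, 1.0]  # weight from the input to each hidden neuron
W_OUT = [0.6, 0.6]     # weight from each hidden neuron to the output


def forward(x, mask):
    """Run the toy network, zeroing every hidden neuron whose mask entry is 0."""
    hidden = [m * leaky_relu(w * x) for w, m in zip(W_HIDDEN, mask)]
    return sum(w * h for w, h in zip(W_OUT, hidden))


def accuracy(mask, xs, ys):
    """Fraction of samples where thresholding the output at 0.5 matches the label."""
    correct = sum((forward(x, mask) >= 0.5) == bool(y) for x, y in zip(xs, ys))
    return correct / len(xs)


def enumerate_circuits(xs, ys, threshold):
    """Keep every neuron subset whose masked network reaches the accuracy threshold."""
    return [mask for mask in product([0, 1], repeat=len(W_HIDDEN))
            if accuracy(mask, xs, ys) >= threshold]


xs = [0.0, 1.0, 0.0, 1.0]
ys = [0, 1, 0, 1]
circuits = enumerate_circuits(xs, ys, threshold=0.99)
print(f"Number of circuits: {len(circuits)}")
```

In this toy case either hidden neuron alone, or both together, reproduces the task, so three of the four candidate subsets pass the threshold.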


The results suggest that multiple incompatible where-then-what explanations exist for the larger model. Indeed, although we cannot enumerate explanations for the shallower part of the model, there are only two cases: it admits either no valid explanation or at least one.

If the shallower model has no valid explanation, then no explanation can be found for the entire model. If it has one (or more), then the grounding of the output of this explanation can be taken as the input to the deeper model. Each explanation of the shallower model therefore combines with each circuit of the deeper model to yield at least one valid grounding (and likely many more) for the entire model.

As a result, the entire model has either no valid explanation, or at least as many as there are circuits in the deeper model.
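
This counting argument can be written compactly. As notation introduced here for illustration only, let $N_s$ be the number of valid explanations of the shallower part, $N_d$ the number of circuits found in the deeper part, and $N$ the number of valid explanations of the entire model:

$$
N = 0 \quad \text{if } N_s = 0, \qquad N \ge N_d \quad \text{if } N_s \ge 1.
$$

With the run above, $N_d = 4702$, so the entire model has either no valid explanation or at least 4702.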