# 4. Image Classification: Interpreting Visual Features

## 4.1. Qualitative Assessment: Top Eigenvectors Appear Interpretable

### NOTES

This notebook reuses code from the [original author’s repository](https://github.com/tdooms/bilinear-interp) to reproduce the paper's results.  
The implementation follows [`features.py`](https://github.com/tdooms/bilinear-interp/blob/main/_old/workspace/paper/features.py) and [`noise.py`](https://github.com/tdooms/bilinear-interp/blob/main/_old/workspace/paper/noise.py) from that repository.

In [1]:
# Add root to the sys.path

import sys
from pathlib import Path

ROOT = Path().resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

In [2]:
# Import libraries and enable autoreload

%load_ext autoreload
%autoreload 2

import torch
import plotly.express as px
from kornia.augmentation import RandomGaussianNoise

from image import FMNIST, LinearModel, MNIST, Model, plot_eigenspectrum

device = "cpu"


The training hyperparameters for the image classification models, as given in Appendix G of the paper, are:

| Parameter        | Value            |
| ---------------- | ---------------- |
| Input Noise Norm | 0.5              |
| Weight Decay     | 1.0              |
| Learning Rate    | 0.001            |
| Batch Size       | 2048             |
| Optimizer        | AdamW            |
| Schedule         | Cosine Annealing |
| Epochs           | 20–100           |
| Hidden Dimension | 512              |

The following cell uses these hyperparameters. Where the paper does not state a hyperparameter, we use the default configuration from the code. We also follow the code in the original author’s repository to reproduce the exact figures shown in the paper.
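For orientation, the optimizer and schedule rows of the table can be sketched in plain PyTorch. This is only an illustration under stated assumptions: `toy_model` is a stand-in module, the schedule is stepped once per epoch, and the minimum learning rate is 0. The notebook's `Model.fit` may configure these details differently.

```python
import torch
from torch import nn

# Values taken from the hyperparameter table; EPOCHS is one choice from the 20-100 range.
LR, WEIGHT_DECAY, EPOCHS = 1e-3, 1.0, 20

# Stand-in model (hypothetical); any nn.Module would do for this illustration.
toy_model = nn.Linear(28 * 28, 10)

optimizer = torch.optim.AdamW(toy_model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
# Cosine annealing decays the learning rate from LR towards 0 over EPOCHS scheduler steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

lrs = []
for _ in range(EPOCHS):
    lrs.append(scheduler.get_last_lr()[0])
    # One dummy optimization step per "epoch" so the scheduler is stepped after the optimizer.
    loss = toy_model(torch.randn(8, 28 * 28)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()

final_lr = scheduler.get_last_lr()[0]
```

The recorded `lrs` start at the base learning rate and decrease every epoch, and `final_lr` ends up numerically at zero.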

In [3]:
# Load the dataset
mnist_train, mnist_test = (
    MNIST(train=True, device=device),
    MNIST(train=False, device=device),
)
fmnist_train, fmnist_test = (
    FMNIST(train=True, device=device),
    FMNIST(train=False, device=device),
)
mnist_labels = [f"{i}" for i in range(10)]
fmnist_labels = [
    "t-shirt/top",
    "trouser",
    "pullover",
    "dress",
    "coat",
    "sandal",
    "shirt",
    "sneaker",
    "bag",
    "ankle boot",
]

# Instantiate the model
mnist_model = Model.from_config(
    wd=1.0, lr=0.001, batch_size=2048, epochs=100, d_hidden=512, seed=420
).to(device)
fmnist_model = Model.from_config(
    wd=1.0, lr=0.001, batch_size=2048, epochs=100, d_hidden=512, seed=420
).to(device)

# Train both models with random Gaussian noise
mnist_metrics = mnist_model.fit(
    mnist_train, mnist_test, RandomGaussianNoise(mean=0, std=0.5, p=1)
)
fmnist_metrics = fmnist_model.fit(
    fmnist_train, fmnist_test, RandomGaussianNoise(mean=0, std=0.5, p=1)
)

train/loss: 0.115, train/acc: 0.964, val/loss: 0.064, val/acc: 0.982: 100%|██████████| 100/100 [01:34<00:00,  1.06it/s]
train/loss: 0.433, train/acc: 0.844, val/loss: 0.403, val/acc: 0.865: 100%|██████████| 100/100 [01:39<00:00,  1.00it/s]


### Figure 2
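Before the figure, a toy sketch of what an eigendecomposition of a bilinear layer can look like. It assumes a layer of the form `logits = P @ ((W @ x) * (V @ x))` with no biases, embedding, or normalization; `W`, `V`, `P`, and the toy sizes are illustrative names and values, and the repository's `decompose()` may differ in detail. Under that assumption each class logit is a quadratic form in the input, so its symmetric interaction matrix can be eigendecomposed.

```python
import torch

torch.manual_seed(0)
d_in, d_hidden, n_classes = 16, 8, 3  # toy sizes, not the notebook's 784/512/10

# Hypothetical weights of a bias-free bilinear layer followed by a linear readout.
W = torch.randn(d_hidden, d_in, dtype=torch.float64)
V = torch.randn(d_hidden, d_in, dtype=torch.float64)
P = torch.randn(n_classes, d_hidden, dtype=torch.float64)


def bilinear_logits(x):
    """Direct forward pass: elementwise product of two linear maps, then readout."""
    return P @ ((W @ x) * (V @ x))


# Interaction matrix per class: B_c = W^T diag(P[c]) V, so that logit_c = x^T B_c x.
B = torch.einsum("ch,hi,hj->cij", P, W, V)
# Only the symmetric part of B_c contributes to the quadratic form.
Q = 0.5 * (B + B.transpose(-1, -2))

# eigh returns eigenvalues in ascending order; eigenvectors are the columns.
vals, vecs = torch.linalg.eigh(Q)
vecs = vecs.transpose(-1, -2)  # reorder to [class, eigenvector index, input dim]

x = torch.randn(d_in, dtype=torch.float64)
logits_direct = bilinear_logits(x)
logits_quadratic = torch.einsum("i,cij,j->c", x, Q, x)
# Spectral form: sum over eigenpairs of eigenvalue * (eigenvector . x)^2
logits_spectral = (vals * (vecs @ x) ** 2).sum(-1)
```

All three ways of computing the logits agree. Because `eigh` sorts eigenvalues in ascending order, index `-1` along the eigenvector axis selects the eigenvector with the largest eigenvalue, which matches the `[..., -1]` indexing used in the cells below.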

In [4]:
# Decompose the eigenvalues and eigenvectors
mnist_vals, mnist_vecs = mnist_model.decompose()
fmnist_vals, fmnist_vecs = fmnist_model.decompose()

# Keep only the classes at indices 1-5 (digits 1-5 for MNIST, trouser through sandal for FMNIST)
idxs = slice(1, 6)
vecs = torch.cat([mnist_vecs[idxs, -1], fmnist_vecs[idxs, -1]])
vecs /= (
    vecs.abs().max(1, keepdim=True).values
)  # Normalize eigenvectors to [-1, 1] for visualization

color = dict(color_continuous_scale="RdBu", color_continuous_midpoint=0.0)
fig = px.imshow(
    vecs.view(-1, 28, 28).cpu(),
    facet_col=0,
    facet_col_wrap=5,
    height=330,
    width=1000,
    facet_row_spacing=0.1,
    **color,
)
fig.update_layout(coloraxis_showscale=False, margin=dict(l=0, r=0, b=0, t=20))
fig.update_xaxes(visible=False)
fig.update_yaxes(visible=False)

labels = fmnist_labels[idxs] + mnist_labels[idxs]
[
    a.update(text=f"<b>{labels[i]}</b>", y=a["y"] + 0.005)
    for i, a in enumerate(fig.layout.annotations)
]
fig.show()

In [5]:
# Same as above, but show all classes

# Decompose the eigenvalues and eigenvectors
mnist_vals, mnist_vecs = mnist_model.decompose()
fmnist_vals, fmnist_vecs = fmnist_model.decompose()

vecs = torch.cat([mnist_vecs[:, -1], fmnist_vecs[:, -1]])
vecs /= (
    vecs.abs().max(1, keepdim=True).values
)  # Normalize eigenvectors to [-1, 1] for visualization

color = dict(color_continuous_scale="RdBu", color_continuous_midpoint=0.0)
fig = px.imshow(
    vecs.view(-1, 28, 28).cpu(),
    facet_col=0,
    facet_col_wrap=10,
    height=330,
    width=1000,
    facet_row_spacing=0.1,
    **color,
)
fig.update_layout(coloraxis_showscale=False, margin=dict(l=0, r=0, b=0, t=20))
fig.update_xaxes(visible=False)
fig.update_yaxes(visible=False)

labels = fmnist_labels + mnist_labels
[
    a.update(text=f"<b>{labels[i]}</b>", y=a["y"] + 0.005)
    for i, a in enumerate(fig.layout.annotations)
]
fig.show()

#### Further Analysis

Let's compare the accuracy against a bilinear model with a ReLU gate (`gate="relu"`) to see whether the ungated bilinear layer is competitive.
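As a rough picture of what a gate changes, here is a hypothetical layer that computes `gate(W x) * (V x)`. The class name `GatedBilinear` and its structure are assumptions made for illustration, not the repository's implementation. Without a gate the layer is a pure quadratic function of its input; with a ReLU gate it is not.

```python
import torch
from torch import nn


class GatedBilinear(nn.Module):
    """Hypothetical layer computing gate(W x) * (V x); not the repository's class."""

    def __init__(self, d_in, d_hidden, gate=None):
        super().__init__()
        self.left = nn.Linear(d_in, d_hidden, bias=False)
        self.right = nn.Linear(d_in, d_hidden, bias=False)
        # No gate means the identity, which gives a pure bilinear layer.
        self.gate = nn.ReLU() if gate == "relu" else nn.Identity()

    def forward(self, x):
        return self.gate(self.left(x)) * self.right(x)


torch.manual_seed(0)
pure = GatedBilinear(32, 16)
relu = GatedBilinear(32, 16, gate="relu")
relu.load_state_dict(pure.state_dict())  # share weights so only the gate differs

x = torch.randn(4, 32)
# A pure bilinear layer is even in its input: flipping the sign of x leaves the output unchanged.
pure_is_even = torch.allclose(pure(x), pure(-x), atol=1e-6)
# The ReLU-gated variant breaks this symmetry.
relu_is_even = torch.allclose(relu(x), relu(-x), atol=1e-6)
```

The sign symmetry of the ungated layer is what makes the quadratic-form (eigenvector) analysis exact; the ReLU-gated variant is included in the notebook only as an accuracy baseline.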

In [6]:
# Instantiate the model
bilinear_relu_mnist_model = Model.from_config(
    wd=1.0, lr=0.001, batch_size=2048, epochs=100, d_hidden=512, seed=420, gate="relu"
).to(device)
bilinear_relu_fmnist_model = Model.from_config(
    wd=1.0, lr=0.001, batch_size=2048, epochs=100, d_hidden=512, seed=420, gate="relu"
).to(device)

# Train both models with random Gaussian noise
bilinear_relu_mnist_metrics = bilinear_relu_mnist_model.fit(
    mnist_train, mnist_test, RandomGaussianNoise(mean=0, std=0.5, p=1)
)
bilinear_relu_fmnist_metrics = bilinear_relu_fmnist_model.fit(
    fmnist_train, fmnist_test, RandomGaussianNoise(mean=0, std=0.5, p=1)
)

train/loss: 0.076, train/acc: 0.976, val/loss: 0.042, val/acc: 0.988: 100%|██████████| 100/100 [01:46<00:00,  1.07s/it]
train/loss: 0.377, train/acc: 0.858, val/loss: 0.362, val/acc: 0.877: 100%|██████████| 100/100 [01:47<00:00,  1.08s/it]


Let's also compare the accuracy against a `LinearModel` baseline with ReLU activation to see whether the bilinear layer is competitive.

In [7]:
# Instantiate the model
linear_mnist_model = LinearModel.from_config(
    wd=1.0, lr=0.001, batch_size=2048, epochs=100, d_hidden=512, seed=420, gate="relu"
).to(device)
linear_fmnist_model = LinearModel.from_config(
    wd=1.0, lr=0.001, batch_size=2048, epochs=100, d_hidden=512, seed=420, gate="relu"
).to(device)

# Train both models with random Gaussian noise
linear_mnist_metrics = linear_mnist_model.fit(
    mnist_train, mnist_test, RandomGaussianNoise(mean=0, std=0.5, p=1)
)
linear_fmnist_metrics = linear_fmnist_model.fit(
    fmnist_train, fmnist_test, RandomGaussianNoise(mean=0, std=0.5, p=1)
)

train/loss: 0.114, train/acc: 0.963, val/loss: 0.056, val/acc: 0.983: 100%|██████████| 100/100 [01:32<00:00,  1.08it/s]
train/loss: 0.406, train/acc: 0.847, val/loss: 0.368, val/acc: 0.868: 100%|██████████| 100/100 [01:33<00:00,  1.07it/s]


Next, check the accuracy per class for the bilinear models.

In [8]:
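def accuracy_per_class(y_hat, y, num_classes):
    """Return a list with the accuracy of each class (NaN if a class is absent)."""
    preds = y_hat.argmax(dim=1)
    accs = []
    for c in range(num_classes):
        idx = y == c
        if idx.any():
            acc = (preds[idx] == y[idx]).float().mean()
        else:
            # Use a tensor (not a Python float) so callers can always call .item()
            acc = torch.tensor(float("nan"))
        accs.append(acc)
    return accs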


y_hat = mnist_model(mnist_test.x)
per_class_acc = accuracy_per_class(y_hat, mnist_test.y, 10)

for i, acc in enumerate(per_class_acc):
    print(mnist_labels[i], acc.item() * 100)

0 99.28571581840515
1 99.47136640548706
2 97.67441749572754
3 98.21782112121582
4 98.1670081615448
5 98.43049049377441
6 98.85177612304688
7 96.98443412780762
8 97.7412760257721
9 97.22497463226318


In [9]:
y_hat = fmnist_model(fmnist_test.x)
per_class_acc = accuracy_per_class(y_hat, fmnist_test.y, 10)

for i, acc in enumerate(per_class_acc):
    print(fmnist_labels[i], acc.item() * 100)

t-shirt/top 86.2999975681305
trouser 97.29999899864197
pullover 75.80000162124634
dress 88.20000290870667
coat 87.00000047683716
sandal 90.20000100135803
shirt 54.29999828338623
sneaker 91.20000004768372
bag 97.69999980926514
ankle boot 97.29999899864197
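
Per-class accuracy shows which classes are hard (for example, shirt above) but not which classes they are confused with. A confusion matrix carries that extra information, and its row-normalized diagonal is exactly the per-class accuracy. The helper below is a minimal sketch on toy tensors; `confusion_matrix` is an illustrative name, not a function from the repository.

```python
import torch


def confusion_matrix(logits, targets, num_classes):
    """Count matrix with true classes as rows and predicted classes as columns."""
    preds = logits.argmax(dim=1)
    # Encode each (true, predicted) pair as a single index and count occurrences.
    flat = targets * num_classes + preds
    counts = torch.bincount(flat, minlength=num_classes**2)
    return counts.view(num_classes, num_classes)


# Toy example with 3 classes and 4 samples.
toy_logits = torch.tensor(
    [[2.0, 0.1, 0.0], [0.0, 1.0, 0.2], [0.3, 0.9, 0.1], [0.0, 0.2, 3.0]]
)
toy_targets = torch.tensor([0, 1, 0, 2])

cm = confusion_matrix(toy_logits, toy_targets, 3)
# Diagonal over row sums gives per-class accuracy; clamp avoids division by zero.
per_class = cm.diag().float() / cm.sum(dim=1).clamp(min=1)
```

Assuming the same objects as in the cells above, `confusion_matrix(fmnist_model(fmnist_test.x), fmnist_test.y, 10)` would show where the misclassified shirts end up.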


### Figure 3
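
The internals of `plot_eigenspectrum` are not shown in this notebook. As a hedged illustration of the kind of selection such a plot needs, the sketch below picks the `k` most positive and `k` most negative eigenpairs from an ascending `torch.linalg.eigh` result. The helper name `top_eigenpairs` and the toy matrix are assumptions for this example only.

```python
import torch


def top_eigenpairs(vals, vecs, k):
    """Split ascending eigenpairs into the k most positive and k most negative.

    vals: [n] eigenvalues in ascending order.
    vecs: [n, d] eigenvectors stored as rows, aligned with vals.
    """
    negative = (vals[:k], vecs[:k])
    # Flip so the largest eigenvalue comes first.
    positive = (vals[-k:].flip(0), vecs[-k:].flip(0))
    return positive, negative


# Toy symmetric matrix with known eigenvalues on the diagonal.
Q = torch.diag(torch.tensor([3.0, -2.0, 1.0, -5.0, 0.5]))
vals, vecs = torch.linalg.eigh(Q)
vecs = vecs.transpose(-1, -2)  # eigh stores eigenvectors as columns; make them rows

(pos_vals, pos_vecs), (neg_vals, neg_vecs) = top_eigenpairs(vals, vecs, k=2)
```

Here `pos_vals` holds the two largest eigenvalues in descending order and `neg_vals` the two most negative ones in ascending order.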

In [10]:
for i in range(10):
    print(f"Eigenspectrum Plot for {mnist_labels[i]}")
    eigenspectrum_fig = plot_eigenspectrum(mnist_model, digit=i, eigenvectors=4)
    eigenspectrum_fig.show()

Eigenspectrum Plot for 0


Eigenspectrum Plot for 1


Eigenspectrum Plot for 2


Eigenspectrum Plot for 3


Eigenspectrum Plot for 4


Eigenspectrum Plot for 5


Eigenspectrum Plot for 6


Eigenspectrum Plot for 7


Eigenspectrum Plot for 8


Eigenspectrum Plot for 9


In [11]:
for i in range(10):
    print(f"Eigenspectrum Plot for {fmnist_labels[i]}")
    eigenspectrum_fig = plot_eigenspectrum(fmnist_model, digit=i, eigenvectors=4)
    eigenspectrum_fig.show()

Eigenspectrum Plot for t-shirt/top


Eigenspectrum Plot for trouser


Eigenspectrum Plot for pullover


Eigenspectrum Plot for dress


Eigenspectrum Plot for coat


Eigenspectrum Plot for sandal


Eigenspectrum Plot for shirt


Eigenspectrum Plot for sneaker


Eigenspectrum Plot for bag


Eigenspectrum Plot for ankle boot


### Figure 4
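
The cells below sweep the noise level used during training. As a plain-PyTorch picture of that augmentation, the sketch assumes additive i.i.d. Gaussian noise with zero mean applied to every sample, which is how the notebook configures `RandomGaussianNoise(mean=0, std=..., p=1)`. The helper `add_gaussian_noise` is illustrative and is not the kornia implementation.

```python
import torch


def add_gaussian_noise(x, std, generator=None):
    """Add zero-mean Gaussian noise with the given standard deviation to x."""
    noise = torch.randn(x.shape, generator=generator, dtype=x.dtype)
    return x + std * noise


gen = torch.Generator().manual_seed(0)
images = torch.rand(64, 1, 28, 28)  # stand-in batch of 28x28 images in [0, 1)

# Same sweep as in the cells below: std = 0.0, 0.2, 0.4, 0.6, 0.8
noise_levels = [i * 0.2 for i in range(5)]
noisy = [add_gaussian_noise(images, std, gen) for std in noise_levels]

# Empirical standard deviation of the perturbation at the largest noise level.
largest_std = (noisy[-1] - images).std().item()
```

With `std=0` the input is returned unchanged, so the first model in the sweep is trained without noise.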

In [12]:
all_vecs = torch.empty([5, 10, 10, 28 * 28])
all_vals = torch.empty([5, 10, 512])
all_accs = torch.empty([5])

# Train one MNIST model per noise level (std = 0.0, 0.2, 0.4, 0.6, 0.8)
for i in range(5):
    model = Model.from_config(
        wd=1.0, lr=0.001, batch_size=2048, epochs=50, d_hidden=512, seed=42
    ).to(device)

    torch.set_grad_enabled(True)
    metrics = model.fit(
        mnist_train, mnist_test, RandomGaussianNoise(mean=0, std=i * 0.2, p=1)
    )
    torch.set_grad_enabled(False)

    vals, vecs = model.decompose()
    all_vecs[i] = vecs[..., -10:, :]
    all_vals[i] = vals
    all_accs[i] = metrics["val/acc"].iloc[-1]

vecs, vals, accs = all_vecs, all_vals, all_accs

subset = vecs[:, 0, -1]
subset /= subset.abs().max(1, keepdim=True).values

fig = px.imshow(
    subset.view(-1, 28, 28).cpu(),
    facet_col=0,
    facet_col_wrap=5,
    height=250,
    width=1000,
    **color,
)

fig.update_layout(coloraxis_showscale=False, margin=dict(l=0, r=0, b=20, t=5))
fig.update_xaxes(visible=False)
fig.update_yaxes(visible=False)

[
    a.update(text=f"{accs[i]:.1%}", y=a["y"] - 0.04)
    for i, a in enumerate(fig.layout.annotations)
]

fig.add_annotation(
    x=0,
    y=-0.0,
    xref="paper",
    yref="paper",
    showarrow=True,
    ax=450,
    ay=0,
    axref="pixel",
    ayref="pixel",
    arrowhead=2,
    arrowsize=1,
    arrowwidth=2,
)
fig.add_annotation(
    x=1,
    y=-0.0,
    xref="paper",
    yref="paper",
    showarrow=True,
    ax=-450,
    ay=0,
    axref="pixel",
    ayref="pixel",
    arrowhead=2,
    arrowsize=1,
    arrowwidth=2,
)
fig.add_annotation(
    text="Noise",
    x=0.5,
    y=-0.05,
    font=dict(size=16),
    xref="paper",
    yref="paper",
    axref="pixel",
    ayref="pixel",
    showarrow=False,
)
fig.add_annotation(
    text="norm=1",
    x=0.97,
    y=-0.1,
    font=dict(size=14),
    xref="paper",
    yref="paper",
    axref="pixel",
    ayref="pixel",
    showarrow=False,
)
fig.add_annotation(
    text="norm=0",
    x=0.02,
    y=-0.1,
    font=dict(size=14),
    xref="paper",
    yref="paper",
    axref="pixel",
    ayref="pixel",
    showarrow=False,
)

fig.show()

train/loss: 0.022, train/acc: 0.995, val/loss: 0.094, val/acc: 0.976: 100%|██████████| 50/50 [00:51<00:00,  1.04s/it]
train/loss: 0.050, train/acc: 0.984, val/loss: 0.071, val/acc: 0.980: 100%|██████████| 50/50 [00:52<00:00,  1.06s/it]
train/loss: 0.100, train/acc: 0.970, val/loss: 0.068, val/acc: 0.981: 100%|██████████| 50/50 [00:56<00:00,  1.13s/it]
train/loss: 0.180, train/acc: 0.944, val/loss: 0.082, val/acc: 0.978: 100%|██████████| 50/50 [00:57<00:00,  1.14s/it]
train/loss: 0.300, train/acc: 0.903, val/loss: 0.113, val/acc: 0.973: 100%|██████████| 50/50 [00:56<00:00,  1.13s/it]


In [13]:
all_vecs = torch.empty([5, 10, 10, 28 * 28])
all_vals = torch.empty([5, 10, 512])
all_accs = torch.empty([5])

# Train one Fashion-MNIST model per noise level (std = 0.0, 0.2, 0.4, 0.6, 0.8)
for i in range(5):
    model = Model.from_config(
        wd=1.0, lr=0.001, batch_size=2048, epochs=50, d_hidden=512, seed=42
    ).to(device)

    torch.set_grad_enabled(True)
    metrics = model.fit(
        fmnist_train, fmnist_test, RandomGaussianNoise(mean=0, std=i * 0.2, p=1)
    )
    torch.set_grad_enabled(False)

    vals, vecs = model.decompose()
    all_vecs[i] = vecs[..., -10:, :]
    all_vals[i] = vals
    all_accs[i] = metrics["val/acc"].iloc[-1]

vecs, vals, accs = all_vecs, all_vals, all_accs

subset = vecs[:, 0, -1]
subset /= subset.abs().max(1, keepdim=True).values

fig = px.imshow(
    subset.view(-1, 28, 28).cpu(),
    facet_col=0,
    facet_col_wrap=5,
    height=250,
    width=1000,
    **color,
)

fig.update_layout(coloraxis_showscale=False, margin=dict(l=0, r=0, b=20, t=5))
fig.update_xaxes(visible=False)
fig.update_yaxes(visible=False)

[
    a.update(text=f"{accs[i]:.1%}", y=a["y"] - 0.04)
    for i, a in enumerate(fig.layout.annotations)
]

fig.add_annotation(
    x=0,
    y=-0.0,
    xref="paper",
    yref="paper",
    showarrow=True,
    ax=450,
    ay=0,
    axref="pixel",
    ayref="pixel",
    arrowhead=2,
    arrowsize=1,
    arrowwidth=2,
)
fig.add_annotation(
    x=1,
    y=-0.0,
    xref="paper",
    yref="paper",
    showarrow=True,
    ax=-450,
    ay=0,
    axref="pixel",
    ayref="pixel",
    arrowhead=2,
    arrowsize=1,
    arrowwidth=2,
)
fig.add_annotation(
    text="Noise",
    x=0.5,
    y=-0.05,
    font=dict(size=16),
    xref="paper",
    yref="paper",
    axref="pixel",
    ayref="pixel",
    showarrow=False,
)
fig.add_annotation(
    text="norm=1",
    x=0.97,
    y=-0.1,
    font=dict(size=14),
    xref="paper",
    yref="paper",
    axref="pixel",
    ayref="pixel",
    showarrow=False,
)
fig.add_annotation(
    text="norm=0",
    x=0.02,
    y=-0.1,
    font=dict(size=14),
    xref="paper",
    yref="paper",
    axref="pixel",
    ayref="pixel",
    showarrow=False,
)

fig.show()

train/loss: 0.274, train/acc: 0.907, val/loss: 0.380, val/acc: 0.874: 100%|██████████| 50/50 [00:55<00:00,  1.12s/it]
train/loss: 0.334, train/acc: 0.884, val/loss: 0.367, val/acc: 0.878: 100%|██████████| 50/50 [00:57<00:00,  1.14s/it]
train/loss: 0.412, train/acc: 0.854, val/loss: 0.395, val/acc: 0.868: 100%|██████████| 50/50 [00:58<00:00,  1.17s/it]
train/loss: 0.498, train/acc: 0.820, val/loss: 0.450, val/acc: 0.853: 100%|██████████| 50/50 [00:59<00:00,  1.19s/it]
train/loss: 0.586, train/acc: 0.784, val/loss: 0.519, val/acc: 0.828: 100%|██████████| 50/50 [00:59<00:00,  1.19s/it]
