# 4. Image Classification: Interpreting Visual Features

## 4.3. Comparing with Ground Truth: Eigenvectors Find Computation

### NOTES

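This notebook reproduces the paper's results using code from the [original authors’ repository](https://github.com/tdooms/bilinear-interp).  
The implementation follows [this script](https://github.com/tdooms/bilinear-interp/blob/main/_old/workspace/paper/scasper.py).

Setup: MNIST images are relabelled as class 1 when their cosine similarity to a fixed target image (or to its complement, `1 - target`) exceeds 0.4, and as class 0 otherwise. Because this ground-truth rule is known, the trained model's top eigenvector can be compared directly with the target image.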

In [1]:
# Make the directory two levels above the current working directory importable
# (presumably the repository root, where project-local modules such as `image` live).
import sys
from pathlib import Path

ROOT = Path().resolve().parents[1]
root_str = str(ROOT)
if root_str not in sys.path:
    sys.path.append(root_str)

In [2]:
# Import Libraries

%load_ext autoreload
%autoreload 2

import torch
from torch.nn.functional import cosine_similarity

import plotly.express as px
from einops import *
from kornia.augmentation import RandomGaussianNoise

from image import MNIST, Model, plot_eigenspectrum

device = "cpu"

  from .autonotebook import tqdm as notebook_tqdm
  Referenced from: <063DD42B-A1DA-3BBE-AD5E-03C12F7E3DBA> /Users/shidqietaufiqurrahman/workspace/miniconda3/envs/fact-ai/lib/python3.12/site-packages/torchvision/image.so
  warn(


In [3]:
def make_label_one_similarity(x, target, threshold=0.4):
    """Binary labels: 1 if a sample resembles `target` or its complement, else 0.

    Every sample of `x` is flattened from dim 1 onward. Each flattened sample is
    compared by cosine similarity with the flattened `target` and with the flattened
    complement `1 - target`.

    Args:
        x: batch of inputs; dim 0 indexes samples.
        target: reference input. After `flatten(1)` it must broadcast against the
            flattened batch, e.g. a single sample with a leading dim of 1.
        threshold: a sample is labelled 1 when either similarity is strictly greater
            than this value. Defaults to 0.4, the previously hard-coded value.

    Returns:
        A LongTensor of shape (len(x),) with values in {0, 1}.
    """
    flat_x = x.flatten(1)
    pos_sims = cosine_similarity(flat_x, target.flatten(1))
    neg_sims = cosine_similarity(flat_x, (1 - target).flatten(1))
    return ((pos_sims > threshold) | (neg_sims > threshold)).long()

In [4]:
# Relabel MNIST as a binary task: label 1 when an image's cosine similarity to
# `target` (training image 6) or to its complement `1 - target` exceeds the
# threshold used by make_label_one_similarity.
torch.manual_seed(0)  # fix torch's RNG so weight init / any torch-driven shuffling is reproducible

mnist_train, mnist_test = (
    MNIST(train=True, device=device),
    MNIST(train=False, device=device),
)
target = mnist_train.x[6]
mnist_train.y = make_label_one_similarity(mnist_train.x, target)
mnist_test.y = make_label_one_similarity(mnist_test.x, target)

mnist_model = Model.from_config(
    wd=1.0,
    lr=0.001,
    batch_size=2048,
    epochs=100,
    d_hidden=64,
    d_output=2,
    bias=True,
).to(device)

# Enable gradients only for training. The context manager restores the previous grad
# mode on exit, even if `fit` raises. Gradients then stay disabled globally for the
# weight analysis that follows.
with torch.enable_grad():
    metrics = mnist_model.fit(
        mnist_train,
        mnist_test,
        # kornia Gaussian noise (std 0.1, p=1 => applied every time); presumably used by
        # `fit` as a training-time augmentation — see image.Model.fit
        RandomGaussianNoise(mean=0, std=0.1, p=1),
    )
torch.set_grad_enabled(False)

def augment_with_bias(weight, bias=None):
    """Extend `weight` with a constant-1 coordinate so an affine map becomes linear.

    Returns `block_diag(weight, [[1]])`, i.e. shape (out + 1, in + 1) for an (out, in)
    weight, built on `weight`'s own device and dtype. If `bias` is given, it is written
    into the new last column of the first `out` rows. The last row stays [0, ..., 0, 1],
    so the extra output coordinate always equals the constant input 1.
    """
    augmented = torch.block_diag(weight, weight.new_ones(1, 1))
    if bias is not None:
        augmented[:-1, -1] = bias
    return augmented


# Fold the layer biases into the weights. With the input augmented as [emb; 1], the
# layer's output becomes a pure bilinear form of the augmented input. New tensors are
# created on the parameters' own device (instead of a hard-coded "cpu"), so this cell
# keeps working when `device` is changed.
bias_l, bias_r = mnist_model.blocks[0].bias.chunk(2)
w_l = augment_with_bias(mnist_model.w_l[0], bias_l)
w_r = augment_with_bias(mnist_model.w_r[0], bias_r)

# The constant hidden coordinate (1 * 1 = 1) gets weight 1 in every output row.
# The column is sized from w_u, so it does not assume exactly two output classes.
w_u_model = mnist_model.w_u
w_u = torch.cat([w_u_model, w_u_model.new_ones(w_u_model.shape[0], 1)], dim=1)
# The embedding passes the constant coordinate through unchanged.
w_e = augment_with_bias(mnist_model.w_e)

# Interaction matrix for output class 1: b[i, j] = sum_out w_u[1, out] * w_l[out, i] * w_r[out, j].
b = einsum(w_u[1], w_l, w_r, "out, out in1, out in2 -> in1 in2")
# Only the symmetric part contributes to x^T b x, and eigh expects a symmetric matrix.
b = 0.5 * (b + b.mT)

# eigh returns eigenvalues in ascending order, with eigenvectors as columns. Projecting
# each eigenvector through the augmented embedding makes row k of `vecs` live in input
# space (input pixels followed by the constant-1 coordinate).
vals, vecs = torch.linalg.eigh(b)
vecs = einsum(vecs, w_e, "emb batch, emb inp -> batch inp")


train/loss: 0.120, train/acc: 0.954, val/loss: 0.110, val/acc: 0.965: 100%|██████████| 100/100 [00:56<00:00,  1.78it/s]


### Figure 6: Eigenvalues and eigenvectors of a model trained to classify based on similarity to a target.

In [8]:
# Top eigenvector, ignoring the bias component.
# `eigh` sorts eigenvalues in ascending order, so the last row of `vecs` belongs to the
# largest eigenvalue. Its final entry is the constant-1 (bias) coordinate; dropping it
# lets the remaining values reshape to a 28x28 image.
color = {"color_continuous_scale": "RdBu", "color_continuous_midpoint": 0.0}
top_eigvec_img = vecs[-1:, :-1].view(-1, 28, 28).cpu()
top_eigvec_fig = px.imshow(
    top_eigvec_img,
    height=300,
    facet_col=0,
    facet_col_wrap=5,
    **color,
)
top_eigvec_fig.show()

In [9]:
# Visualize the bias.
# `b` acts on the embedding augmented with a constant 1, and `b` is symmetric. Its last
# column therefore pairs each embedding coordinate with that constant. The linear part
# of the form is 2 * b[:-1, -1] (b[i, -1] and b[-1, i] are equal and both contribute).
linear_term_scale = 2
c = b[:-1, -1] * linear_term_scale  # linear coefficient per embedding dimension
c_pixel = mnist_model.w_e.T @ c  # map back through the embedding: one value per pixel
bias_img = c_pixel.view(28, 28).cpu()
px.imshow(
    bias_img,
    title="Linear Bias Term (Pixel Space)",
    color_continuous_midpoint=0.0,
    color_continuous_scale="RdBu",
).show()


In [10]:
# Show the target digit that the similarity labels were defined against,
# as a small axis-free thumbnail without a colour bar.
target = mnist_train.x[6]
target_img = target.cpu().view(28, 28)

fig = (
    px.imshow(target_img, width=200, height=200, **color)
    .update_layout(coloraxis_showscale=False, margin=dict(l=0, r=0, b=0, t=5))
    .update_xaxes(visible=False)
    .update_yaxes(visible=False)
)
fig

In [11]:
# Eigenspectrum plot from the project's `image` helpers.
# `digit=1` presumably selects output class 1 (the "similar to target" label here), and
# `eigenvectors=3` presumably sets how many eigenvectors are drawn. Confirm in
# image.plot_eigenspectrum.
eigenspectrum_fig = plot_eigenspectrum(
    mnist_model,
    digit=1,
    eigenvectors=3,
)
eigenspectrum_fig.show()