# 4. Image Classification: Interpreting Visual Features

## 4.2. Quantitative Assessment: Eigenvectors Learn Consistent Patterns

In [1]:
# Import libraries

%load_ext autoreload
%autoreload 2

from itertools import product

import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
import torch
from einops import *
from kornia.augmentation import RandomGaussianNoise
from scipy import stats
from torch.nn.functional import cosine_similarity

from image import MNIST, Model

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = "cpu"

# Every plotly figure below uses a white background.
pio.templates.default = "plotly_white"

# MNIST splits, loaded directly onto `device` (train split first, then test).
mnist_train = MNIST(train=True, device=device)
mnist_test = MNIST(train=False, device=device)


def conf_interval(sims, mean, conf=0.95):
    """Return a two-sided Student-t confidence interval ``(lower, upper)`` around ``mean``.

    Samples are taken along dim -2 of ``sims``. The half-width is
    ``t_{(1 + conf) / 2, n - 1} * std / sqrt(n)``, where ``std`` is torch's default
    (Bessel-corrected) standard deviation over that dim and ``n`` is its length.

    Args:
        sims: tensor whose second-to-last dim indexes the samples.
        mean: centre of the interval (callers pass the mean over dim -2).
        conf: confidence level, e.g. 0.9 for a 90% interval (default 0.95).
    """
    n = sims.shape[-2]
    std_err = torch.std(sims, dim=-2) / torch.sqrt(torch.tensor(n))
    half_width = stats.t.ppf((1 + conf) / 2, n - 1) * std_err
    return mean - half_width, mean + half_width

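### Similarity

For each model size we train five seeds and compare their per-digit eigenvectors using absolute cosine similarity (an eigenvector's sign is arbitrary).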

In [3]:
# Train 5 seeds of every model size and keep 20 eigenvectors per class.
# features: [model_size, seed, class, eigenvector, pixel] = [6, 5, 10, 20, 784]
sizes = [30, 50, 100, 300, 500, 1000]
n_seeds, n_classes, n_top, n_pixels = 5, 10, 20, 784
features = torch.empty(len(sizes), n_seeds, n_classes, n_top, n_pixels)

# Only `fit` needs autograd. Grad stays disabled for decomposition, for copying into
# `features` and for later cells; the context manager re-disables it even if
# training raises.
torch.set_grad_enabled(False)

for d, i in product(range(len(sizes)), range(n_seeds)):
    # Initiate model
    mnist_model = Model.from_config(
        wd=1.0, lr=0.001, batch_size=2048, epochs=100, d_hidden=sizes[d], seed=i
    ).to(device)

    # Train model
    with torch.enable_grad():
        mnist_model.fit(mnist_train, mnist_test, RandomGaussianNoise(mean=0, std=0.4, p=1))

    # NOTE(review): assumes decompose() orders eigenvectors so the first `n_top`
    # along dim 1 are the "top" ones - confirm against Model.decompose.
    _, vecs = mnist_model.decompose()
    features[d, i] = vecs[:, :n_top, :]

train/loss: 0.242, train/acc: 0.931, val/loss: 0.161, val/acc: 0.959: 100%|██████████| 100/100 [01:00<00:00,  1.66it/s]
train/loss: 0.230, train/acc: 0.933, val/loss: 0.150, val/acc: 0.961: 100%|██████████| 100/100 [01:02<00:00,  1.61it/s]
train/loss: 0.235, train/acc: 0.932, val/loss: 0.153, val/acc: 0.959: 100%|██████████| 100/100 [01:04<00:00,  1.56it/s]
train/loss: 0.247, train/acc: 0.930, val/loss: 0.159, val/acc: 0.959: 100%|██████████| 100/100 [01:03<00:00,  1.57it/s]
train/loss: 0.221, train/acc: 0.937, val/loss: 0.146, val/acc: 0.961: 100%|██████████| 100/100 [01:05<00:00,  1.53it/s]
train/loss: 0.183, train/acc: 0.947, val/loss: 0.120, val/acc: 0.966: 100%|██████████| 100/100 [01:03<00:00,  1.57it/s]
train/loss: 0.177, train/acc: 0.948, val/loss: 0.118, val/acc: 0.967: 100%|██████████| 100/100 [01:07<00:00,  1.49it/s]
train/loss: 0.187, train/acc: 0.946, val/loss: 0.121, val/acc: 0.967: 100%|██████████| 100/100 [01:09<00:00,  1.43it/s]
train/loss: 0.189, train/acc: 0.946, val

In [4]:
# Compare every model's eigenvectors with those of the size-300 models (sizes[3]),
# over all pairs of distinct seeds.
s = slice(-20, None)  # eigenvector slice; `features` stores exactly 20, so this keeps all
sims = cosine_similarity(
    features[3, None, None, :, :, s, :],  # [1, 1, 5, 10, 20, 784]
    features[:, :, None, :, s, :],  # [6, 5, 1, 10, 20, 784]
    dim=-1,
)  # [6, 5, 5, 10, 20] -> [model_size, seed, reference_seed, class, eigenvector_rank]

# Keep each unordered pair of *distinct* seeds once. offset=1 excludes the diagonal.
# With the default offset=0 the size-300 row would include every model compared
# with itself (similarity exactly 1), inflating that curve.
# abs() because an eigenvector's sign is arbitrary.
idxs = torch.triu_indices(5, 5, offset=1)
sims = rearrange(
    sims[:, idxs[0], idxs[1]].abs(), "... batch cls comp -> ... (batch cls) comp"
)  # [6, 100, 20] -> [model_size, (seed_pair, class), eigenvector_rank]

# plt.cm.get_cmap is deprecated (removed in Matplotlib 3.11); pyplot.get_cmap is supported.
viridis = plt.get_cmap("viridis")
colors = [viridis(i)[:3] for i in [0.0, 0.25, 0.5, 0.75, 0.9, 1.0]]
colors = [f"rgb({int(r * 255)}, {int(g * 255)}, {int(b * 255)})" for r, g, b in colors]

fig = go.Figure()
for i in range(len(sizes)):
    mean = sims[i].mean(dim=-2)  # average over (seed_pair, class) comparisons
    low, up = conf_interval(sims[i], mean, conf=0.9)  # 90% confidence interval
    x = torch.arange(len(mean))

    # Add mean line
    fig.add_trace(
        go.Scatter(
            x=x,
            y=mean,
            mode="lines",
            name=f"{sizes[i]}",
            line=dict(color=colors[i]),
        )
    )

    # Add error band (polygon: upper edge left-to-right, lower edge right-to-left)
    fig.add_trace(
        go.Scatter(
            x=torch.cat([x, x.flip(0)]),
            y=torch.cat([up, low.flip(0)]),
            fill="toself",
            fillcolor=colors[i].replace("rgb", "rgba").replace(")", ", 0.2)"),
            line=dict(color="rgba(255,255,255,0)"),
            hoverinfo="skip",
            showlegend=False,
        )
    )

# Update layout
fig.update_layout(title="Similarity Across Eigenvectors", title_x=0.5)
fig.update_layout(
    showlegend=True, width=600, height=400, legend_title_text="Model Size"
)
fig.update_xaxes(title="Eigenvector rank")
fig.update_yaxes(title="Cosine similarity", range=[0.00, 1.01])
fig.show()

  viridis = plt.cm.get_cmap("viridis")


In [6]:
# Heatmap: similarity of eigenvector 0 between every pair of model sizes, averaged
# over classes and pairs of distinct seeds.
s = slice(-20, None)  # same eigenvector slice as the per-rank comparison
size_sims = cosine_similarity(
    features[:, None, None, :, :, s, :],  # [6, 1, 1, 5, 10, 20, 784]
    features[:, :, None, :, s, :],  # [6, 5, 1, 10, 20, 784], broadcast as [1, 6, 5, 1, ...]
    dim=-1,
)  # [6, 6, 5, 5, 10, 20] -> [size_a, size_b, seed_b, seed_a, class, eigenvector_rank]

# offset=1 drops same-seed pairs. Otherwise the heatmap diagonal would include each
# model compared with itself (similarity exactly 1). abs() because eigenvector sign
# is arbitrary.
idxs = torch.triu_indices(5, 5, offset=1)
size_sims = rearrange(
    size_sims[:, :, idxs[0], idxs[1]].abs(),
    "... batch cls comp -> ... (batch cls) comp",
)  # [6, 6, 100, 20]

size_labels = [f"{size}" for size in sizes]
fig = px.imshow(
    size_sims[..., 0].mean(-1),  # eigenvector index 0, averaged over (seed_pair, class)
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0.5,
    zmin=0,
    zmax=1,
)
fig.update_layout(margin=dict(l=0, r=0, b=0, t=20))
fig.update_xaxes(ticktext=size_labels, tickvals=list(range(len(sizes))))
fig.update_yaxes(ticktext=size_labels, tickvals=list(range(len(sizes))))
fig.update_layout(width=463, height=400)
fig.show()

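### Truncation

For each digit we keep only the k eigenvectors with the largest |eigenvalue| and measure test accuracy for k = 0, …, 30.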

In [8]:
def eval_truncated(data, vals, vecs, k=20):
    """Compute logits using only the top-``k`` eigen-components of each output.

    For every output ``o`` the ``k`` eigenvalues with the largest magnitude are kept
    (with their original signs), and

        logits[b, o] = sum_h vals[o, h] * (x_b . vecs[o, h]) ** 2

    over the kept ``h``, where ``x_b`` is ``data[b]`` flattened.

    Args:
        data: inputs with batch on dim 0; all remaining dims are flattened.
        vals: eigenvalues, shape ``[out, hid]``.
        vecs: eigenvectors, shape ``[out, hid, inp]`` (``inp`` = flattened input size).
        k: number of eigen-components to keep per output. ``k=0`` gives all-zero
            logits. A ``k`` larger than ``hid`` keeps every component.

    Returns:
        Logits of shape ``[batch, out]``.
    """
    # Clamp so requesting more components than exist keeps all of them instead
    # of making topk raise.
    k = min(k, vals.size(-1))

    # Select by magnitude, but weight by the signed eigenvalues.
    top_k_indices = vals.abs().topk(k, dim=-1).indices  # [out, k]
    top_k_vals = torch.gather(vals, -1, top_k_indices)  # [out, k]

    # Gather the matching eigenvectors along the hid dimension.
    expanded_indices = top_k_indices.unsqueeze(-1).expand(-1, -1, vecs.size(-1))
    top_k_vecs = torch.gather(vecs, 1, expanded_indices)  # [out, k, inp]

    # Squared projections of each input onto each kept eigenvector.
    proj = torch.einsum("bi,ohi->bho", data.flatten(start_dim=1), top_k_vecs).pow(2)
    return torch.einsum("bho,oh->bo", proj, top_k_vals)

In [10]:
# Train 5 seeds of every model size and measure test accuracy when only the top-k
# eigen-components per class are kept (k = 0..max_k), plus full-model accuracy.
# NOTE(review): these are the same configs and seeds as the similarity section.
# Caching those trained models would avoid retraining all of them here.
sizes = [30, 50, 100, 300, 500, 1000]
n_seeds, max_k = 5, 30

# results: [model_size, seed, truncation_level]; level 0 means no eigenvectors are used
results = torch.empty(len(sizes), n_seeds, max_k + 1)
# ground: [model_size, seed] accuracy of the untruncated model
ground = torch.empty(len(sizes), n_seeds)

# Only `fit` needs autograd. The context manager re-disables grad even if training raises.
torch.set_grad_enabled(False)

for d, i in product(range(len(sizes)), range(n_seeds)):
    # Initiate model
    mnist_model = Model.from_config(
        wd=1.0, lr=0.001, batch_size=2048, epochs=100, d_hidden=sizes[d], seed=i
    ).to(device)

    # Train model
    with torch.enable_grad():
        mnist_model.fit(mnist_train, mnist_test, RandomGaussianNoise(mean=0, std=0.4, p=1))

    vals, vecs = mnist_model.decompose()

    for k in range(max_k + 1):
        logits = eval_truncated(mnist_test.x, vals, vecs, k=k)
        results[d, i, k] = (logits.argmax(dim=1) == mnist_test.y).float().mean().item()

    ground[d, i] = (
        (mnist_model(mnist_test.x).argmax(dim=1) == mnist_test.y).float().mean().item()
    )

train/loss: 0.242, train/acc: 0.931, val/loss: 0.161, val/acc: 0.959: 100%|██████████| 100/100 [01:00<00:00,  1.65it/s]
train/loss: 0.230, train/acc: 0.933, val/loss: 0.150, val/acc: 0.961: 100%|██████████| 100/100 [01:02<00:00,  1.59it/s]
train/loss: 0.235, train/acc: 0.932, val/loss: 0.153, val/acc: 0.959: 100%|██████████| 100/100 [01:04<00:00,  1.56it/s]
train/loss: 0.247, train/acc: 0.930, val/loss: 0.159, val/acc: 0.959: 100%|██████████| 100/100 [01:09<00:00,  1.44it/s]
train/loss: 0.221, train/acc: 0.937, val/loss: 0.146, val/acc: 0.961: 100%|██████████| 100/100 [01:10<00:00,  1.41it/s]
train/loss: 0.183, train/acc: 0.947, val/loss: 0.120, val/acc: 0.966: 100%|██████████| 100/100 [01:12<00:00,  1.38it/s]
train/loss: 0.177, train/acc: 0.948, val/loss: 0.118, val/acc: 0.967: 100%|██████████| 100/100 [01:12<00:00,  1.38it/s]
train/loss: 0.187, train/acc: 0.946, val/loss: 0.121, val/acc: 0.967: 100%|██████████| 100/100 [01:14<00:00,  1.34it/s]
train/loss: 0.189, train/acc: 0.946, val

In [11]:
# Classification error vs. number of kept eigenvectors, one curve per model size
# (mean over seeds with a 90% t-interval band).
fig = go.Figure()

# plt.cm.get_cmap is deprecated (removed in Matplotlib 3.11); pyplot.get_cmap is supported.
viridis = plt.get_cmap("viridis")
colors = [viridis(i)[:3] for i in [0.0, 0.25, 0.5, 0.75, 0.9, 1.0]]
colors = [f"rgb({int(r * 255)}, {int(g * 255)}, {int(b * 255)})" for r, g, b in colors]

error = 1 - results  # [model_size, seed, truncation_level]
y_range = [-2.02, 0.02]  # log10 bounds of the y-axis
y_floor = 10 ** y_range[0]

for i in range(len(sizes)):
    mean = error[i].mean(dim=0)  # average over seeds
    low, up = conf_interval(error[i], mean, conf=0.9)
    # A log axis cannot place values <= 0; pin the band to the axis floor so the
    # filled polygon is drawn intact.
    low, up = low.clamp(min=y_floor), up.clamp(min=y_floor)
    x = torch.arange(len(mean))

    # Add mean line
    fig.add_trace(
        go.Scatter(
            x=x,
            y=mean,
            mode="lines",
            name=f"{sizes[i]}",
            line=dict(color=colors[i]),
        )
    )

    # Add error band (polygon: upper edge left-to-right, lower edge right-to-left)
    fig.add_trace(
        go.Scatter(
            x=torch.cat([x, x.flip(0)]),
            y=torch.cat([up, low.flip(0)]),
            fill="toself",
            fillcolor=colors[i].replace("rgb", "rgba").replace(")", ", 0.2)"),
            line=dict(color="rgba(255,255,255,0)"),
            hoverinfo="skip",
            showlegend=False,
        )
    )


def small(text):
    """Wrap a tick label in a span with a smaller font size."""
    return f"<span style='font-size: 9px;'>{text}</span>"


# Update layout
fig.update_layout(title="Truncation Across Sizes", title_x=0.5)
fig.update_yaxes(
    tickvals=[0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1],
    ticktext=[
        "1%",
        small("2%"),
        small("5%"),
        "10%",
        small("20%"),
        small("50%"),
        "100%",
    ],
    range=y_range,
    type="log",
)
fig.update_layout(width=600, height=400, legend_title_text="Model Size")
fig.update_xaxes(title="Eigenvector rank (per digit)")
fig.update_yaxes(title="Classification error")

fig.show()


The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead.



In [12]:
# Accuracy lost by truncation (full-model accuracy minus truncated accuracy) vs.
# number of kept eigenvectors, one curve per model size (mean over seeds with a
# 90% t-interval band).
fig = go.Figure()

# plt.cm.get_cmap is deprecated (removed in Matplotlib 3.11); pyplot.get_cmap is supported.
viridis = plt.get_cmap("viridis")
colors = [viridis(i)[:3] for i in [0.0, 0.25, 0.5, 0.75, 0.9, 1.0]]
colors = [f"rgb({int(r * 255)}, {int(g * 255)}, {int(b * 255)})" for r, g, b in colors]

diff = ground[..., None] - results  # [model_size, seed, truncation_level]
y_range = [-3.02, 0.02]  # log10 bounds of the y-axis
y_floor = 10 ** y_range[0]

for i in range(len(sizes)):
    mean = diff[i].mean(dim=0)  # average over seeds
    low, up = conf_interval(diff[i], mean, conf=0.9)
    # The drop is <= 0 whenever truncated accuracy matches or beats the full model.
    # A log axis cannot place such values, so the band polygon would break.
    # Pin it to the axis floor instead.
    low, up = low.clamp(min=y_floor), up.clamp(min=y_floor)
    x = torch.arange(len(mean))

    # Add mean line
    fig.add_trace(
        go.Scatter(
            x=x,
            y=mean,
            mode="lines",
            name=f"{sizes[i]}",
            line=dict(color=colors[i]),
        )
    )

    # Add error band (polygon: upper edge left-to-right, lower edge right-to-left)
    fig.add_trace(
        go.Scatter(
            x=torch.cat([x, x.flip(0)]),
            y=torch.cat([up, low.flip(0)]),
            fill="toself",
            fillcolor=colors[i].replace("rgb", "rgba").replace(")", ", 0.2)"),
            line=dict(color="rgba(255,255,255,0)"),
            hoverinfo="skip",
            showlegend=False,
        )
    )

# Update layout
fig.update_layout(title="Truncation Across Sizes", title_x=0.5)
fig.update_yaxes(
    tickvals=[0.001, 0.01, 0.1, 1],
    ticktext=["0.1%", "1%", "10%", "100%"],
    range=y_range,
    type="log",
)
fig.update_layout(width=600, height=400, legend_title_text="Model Size")
fig.update_xaxes(title="Eigenvector rank (per digit)")
fig.update_yaxes(title="Accuracy Drop")

fig.show()


The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead.

