# 4. Image Classification: Interpreting Visual Features

## 4.1. Qualitative Assessment: Top Eigenvectors Appear Interpretable

In [1]:
# Import libraries

%load_ext autoreload
%autoreload 2

import torch
import plotly.express as px
import plotly.graph_objects as go
from kornia.augmentation import RandomGaussianNoise
from plotly.subplots import make_subplots

from image import CIFAR10, Model, SVHN

device = "cpu"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the grayscale train/test splits of CIFAR-10 and SVHN onto `device`
cifar10_train, cifar10_test = (
    CIFAR10(train=True, device=device, grayscale=True),
    CIFAR10(train=False, device=device, grayscale=True),
)
svhn_train, svhn_test = (
    SVHN(split="train", device=device, grayscale=True),
    SVHN(split="test", device=device, grayscale=True),
)
# Display names; later cells look them up with the same integer they pass as `digit`
cifar10_labels = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]
svhn_labels = [f"{i}" for i in range(10)]

# Instantiate one model per dataset with identical hyperparameters and seed
# NOTE(review): d_input=1024 presumably matches a flattened 32x32 grayscale image — confirm in `image.Model`
cifar10_model = Model.from_config(
    wd=1.0, lr=0.001, batch_size=2048, epochs=100, d_input=1024, d_hidden=512, seed=420
).to(device)
svhn_model = Model.from_config(
    wd=1.0, lr=0.001, batch_size=2048, epochs=100, d_input=1024, d_hidden=512, seed=420
).to(device)

# Train both models; a RandomGaussianNoise(mean=0, std=0.5, p=1) transform is
# passed to `fit` as its third argument
cifar10_metrics = cifar10_model.fit(
    cifar10_train, cifar10_test, RandomGaussianNoise(mean=0, std=0.5, p=1)
)
svhn_metrics = svhn_model.fit(
    svhn_train, svhn_test, RandomGaussianNoise(mean=0, std=0.5, p=1)
)

Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: ./data/train_32x32.mat
Using downloaded and verified file: ./data/test_32x32.mat


train/loss: 1.721, train/acc: 0.393, val/loss: 1.696, val/acc: 0.403: 100%|██████████| 100/100 [01:28<00:00,  1.13it/s]
train/loss: 1.490, train/acc: 0.499, val/loss: 1.194, val/acc: 0.760: 100%|██████████| 100/100 [02:23<00:00,  1.43s/it]


### Figure 2

In [3]:
# Decompose each model into its eigenvalues and eigenvectors
cifar10_vals, cifar10_vecs = cifar10_model.decompose()
svhn_vals, svhn_vecs = svhn_model.decompose()

# Keep only the last eigenvector (index -1) along the second axis of each
# decomposition: CIFAR-10 rows first, then SVHN rows.
vecs = torch.cat([cifar10_vecs[:, -1], svhn_vecs[:, -1]])
# Scale every row by its largest absolute entry so the values lie in [-1, 1]
# for visualization (out-of-place, so the concatenated tensor is not mutated).
vecs = vecs / vecs.abs().max(dim=1, keepdim=True).values

# Diverging colour scale centred on zero
color = dict(color_continuous_scale="RdBu", color_continuous_midpoint=0.0)
fig = px.imshow(
    vecs.view(-1, 32, 32).cpu(),
    facet_col=0,
    facet_col_wrap=10,
    height=330,
    width=1000,
    facet_row_spacing=0.1,
    **color,
)
fig.update_layout(coloraxis_showscale=False, margin=dict(l=0, r=0, b=0, t=20))
fig.update_xaxes(visible=False)
fig.update_yaxes(visible=False)

# NOTE(review): the titles are listed SVHN-first although the image stack is
# CIFAR-10-first — presumably because plotly orders the annotations of wrapped
# facets starting from the bottom row. Verify against the rendered figure
# before changing this order.
labels = svhn_labels + cifar10_labels
# Replace the default facet titles with bold class names, nudged slightly upwards
for i, annotation in enumerate(fig.layout.annotations):
    annotation.update(text=f"<b>{labels[i]}</b>", y=annotation["y"] + 0.005)
fig.show()

### Figure 3

In [4]:
# Change the plot_eigenspectrum to output 32x32 image
def plot_eigenspectrum(
    model,
    digit,
    eigenvectors=3,
    eigenvalues=20,
    ignore_pos=(),
    ignore_neg=(),
    image_shape=(32, 32),
):
    """Plot both ends of the eigenspectrum of one output class.

    The figure has two rows and ``1 + eigenvectors`` columns. Row 1 uses the
    entries at the end of the decomposition (indices ``-1, -2, ...``), row 2
    the entries at the start (indices ``0, 1, ...``). Column 1 of each row is a
    line plot of ``eigenvalues + 2`` eigenvalues with the displayed ones
    marked; the remaining columns are heatmaps of the matching eigenvectors.

    NOTE(review): the ``positive`` / ``negative`` naming suggests that
    ``model.decompose()`` returns eigenvalues in ascending order — confirm.

    Args:
        model: Object whose ``decompose()`` returns ``(vals, vecs)``, both
            indexable by ``digit``.
        digit: Index of the output class to visualize.
        eigenvectors: Number of eigenvectors drawn per row.
        eigenvalues: Size of the eigenvalue window in the line plots
            (``eigenvalues + 2`` points are drawn); also the single x tick.
        ignore_pos: Positions (0 = last eigenvalue) among the displayed row-1
            eigenvalues whose y tick is omitted, e.g. to avoid overlapping
            tick labels.
        ignore_neg: Same as ``ignore_pos`` for row 2 (0 = first eigenvalue).
        image_shape: ``(height, width)`` each eigenvector is reshaped to.

    Returns:
        The assembled plotly figure (not shown).
    """
    colors = px.colors.qualitative.Plotly
    fig = make_subplots(rows=2, cols=1 + eigenvectors)

    vals, vecs = model.decompose()
    vals, vecs = vals[digit].cpu(), vecs[digit].cpu()

    # Index tensors: `negative` = 0, 1, ... (start of the spectrum),
    # `positive` = -1, -2, ... (end of the spectrum).
    negative = torch.arange(eigenvectors)
    positive = -1 - negative

    # Row 1, column 1: the tail of the spectrum, flipped so the last
    # eigenvalue sits at x=0, with the displayed eigenvalues marked.
    fig.add_trace(
        go.Scatter(y=vals[-eigenvalues - 2 :].flip(0), mode="lines"), row=1, col=1
    )
    fig.add_trace(
        go.Scatter(
            x=negative.flip(0),
            y=vals[positive].flip(0),
            mode="markers",
            marker=dict(color=colors[0]),
        ),
        row=1,
        col=1,
    )

    # Row 2, column 1: the head of the spectrum with its markers.
    fig.add_trace(
        go.Scatter(
            y=vals[: eigenvalues + 2], mode="lines", marker=dict(color=colors[1])
        ),
        row=2,
        col=1,
    )
    fig.add_trace(
        go.Scatter(
            x=negative, y=vals[negative], mode="markers", marker=dict(color=colors[1])
        ),
        row=2,
        col=1,
    )

    # Columns 2..: one heatmap per displayed eigenvector. The rows are flipped
    # because a Heatmap's y axis increases upwards.
    for row, indices in ((1, positive), (2, negative)):
        for col, idx in enumerate(indices, start=2):
            fig.add_trace(
                go.Heatmap(
                    z=vecs[idx].view(*image_shape).flip(0),
                    colorscale="RdBu",
                    zmid=0,
                    showscale=False,
                ),
                row=row,
                col=col,
            )

    # Hide every axis, then re-enable only what the line plots need.
    fig.update_xaxes(visible=False).update_yaxes(visible=False)
    fig.update_xaxes(
        visible=True,
        tickvals=[eigenvalues],
        ticktext=[f"{eigenvalues}"],
        zeroline=False,
        col=1,
    )
    fig.update_yaxes(zeroline=True, rangemode="tozero", col=1)

    # y ticks: zero plus each displayed eigenvalue that is not ignored.
    for row, indices, ignored in (
        (1, positive, ignore_pos),
        (2, negative, ignore_neg),
    ):
        tickvals = [0] + [
            val.item() for i, val in enumerate(vals[indices]) if i not in ignored
        ]
        ticktext = [f"{val:.2f}" for val in tickvals]
        fig.update_yaxes(
            visible=True, tickvals=tickvals, ticktext=ticktext, col=1, row=row
        )

    fig.update_coloraxes(showscale=False)
    fig.update_layout(
        autosize=False,
        width=170 * (eigenvectors + 1),
        height=300,
        margin=dict(l=0, r=0, b=0, t=0),
        template="plotly_white",
    )
    fig.update_legends(visible=False)

    return fig


In [5]:
# Show the eigenspectrum figure (top and bottom 4 eigenvectors) for every
# class of the SVHN model
for i, class_name in enumerate(svhn_labels):
    print(f"Eigenspectrum Plot for {class_name}")
    eigenspectrum_fig = plot_eigenspectrum(svhn_model, digit=i, eigenvectors=4)
    eigenspectrum_fig.show()

Eigenspectrum Plot for 0


Eigenspectrum Plot for 1


Eigenspectrum Plot for 2


Eigenspectrum Plot for 3


Eigenspectrum Plot for 4


Eigenspectrum Plot for 5


Eigenspectrum Plot for 6


Eigenspectrum Plot for 7


Eigenspectrum Plot for 8


Eigenspectrum Plot for 9


In [5]:
# Show the eigenspectrum figure (top and bottom 4 eigenvectors) for every
# class of the grayscale CIFAR-10 model
for i, class_name in enumerate(cifar10_labels):
    print(f"Eigenspectrum Plot for {class_name}")
    eigenspectrum_fig = plot_eigenspectrum(cifar10_model, digit=i, eigenvectors=4)
    eigenspectrum_fig.show()

Eigenspectrum Plot for airplane


Eigenspectrum Plot for automobile


Eigenspectrum Plot for bird


Eigenspectrum Plot for cat


Eigenspectrum Plot for deer


Eigenspectrum Plot for dog


Eigenspectrum Plot for frog


Eigenspectrum Plot for horse


Eigenspectrum Plot for ship


Eigenspectrum Plot for truck


# Filtered Labels (Dog vs. Ship)

In [6]:
# Load grayscale CIFAR-10 restricted to labels 5 and 8
filtered_cifar10_train, filtered_cifar10_test = (
    CIFAR10(train=True, device=device, grayscale=True, labels=[5, 8]),
    CIFAR10(train=False, device=device, grayscale=True, labels=[5, 8]),
)
# Display names of the two kept classes (entries 5 and 8 of `cifar10_labels`)
# NOTE(review): assumes the filtered dataset maps label 5 -> class 0 and
# label 8 -> class 1 — confirm in `image.CIFAR10`
filtered_cifar10_labels = ["dog", "ship"]

# Instantiate the model with a single block (n_layer=1) and two outputs (d_output=2)
filtered_cifar10_model = Model.from_config(
    wd=1.0,
    lr=0.001,
    batch_size=2048,
    epochs=100,
    n_layer=1,
    d_input=1024,
    d_hidden=512,
    d_output=2,
    seed=420,
).to(device)

# Print the module structure as a sanity check of the configuration
print(filtered_cifar10_model)

Files already downloaded and verified
Files already downloaded and verified
Model(
  (embed): Linear(
    in_features=1024, out_features=512, bias=False
    (gate): Identity()
  )
  (blocks): ModuleList(
    (0): Bilinear(
      in_features=512, out_features=1024, bias=False
      (gate): Identity()
    )
  )
  (head): Linear(
    in_features=512, out_features=2, bias=False
    (gate): Identity()
  )
  (criterion): CrossEntropyLoss()
)


In [7]:
# Train the filtered model; a RandomGaussianNoise(mean=0, std=0.5, p=1)
# transform is passed to `fit` as its third argument
filtered_cifar10_metrics = filtered_cifar10_model.fit(
    filtered_cifar10_train,
    filtered_cifar10_test,
    RandomGaussianNoise(mean=0, std=0.5, p=1),
)

train/loss: 0.391, train/acc: 0.824, val/loss: 0.387, val/acc: 0.820: 100%|██████████| 100/100 [00:14<00:00,  6.80it/s]


In [8]:
# Show the eigenspectrum figure (top and bottom 4 eigenvectors) for both
# classes of the filtered (dog vs. ship) grayscale CIFAR-10 model
for i, class_name in enumerate(filtered_cifar10_labels):
    print(f"Eigenspectrum Plot for {class_name}")
    eigenspectrum_fig = plot_eigenspectrum(
        filtered_cifar10_model, digit=i, eigenvectors=4
    )
    eigenspectrum_fig.show()

Eigenspectrum Plot for dog


Eigenspectrum Plot for ship


# RGB

In [9]:
# Load the CIFAR-10 train/test splits with grayscale conversion disabled
rgb_cifar10_train, rgb_cifar10_test = (
    CIFAR10(train=True, device=device, grayscale=False),
    CIFAR10(train=False, device=device, grayscale=False),
)
# Display names; the plotting cell looks them up with the integer it passes as `digit`
rgb_cifar10_labels = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

# Instantiate the model (d_input is 3072 here instead of the 1024 used for
# the grayscale models)
rgb_cifar10_model = Model.from_config(
    wd=1.0, lr=0.001, batch_size=2048, epochs=100, d_input=3072, d_hidden=512, seed=420
).to(device)

# Train the RGB model; a RandomGaussianNoise(mean=0, std=0.5, p=1) transform
# is passed to `fit` as its third argument
rgb_cifar10_metrics = rgb_cifar10_model.fit(
    rgb_cifar10_train, rgb_cifar10_test, RandomGaussianNoise(mean=0, std=0.5, p=1)
)

Files already downloaded and verified
Files already downloaded and verified


train/loss: 1.455, train/acc: 0.491, val/loss: 1.433, val/acc: 0.506: 100%|██████████| 100/100 [03:34<00:00,  2.14s/it]


In [10]:
# Modify the plot_eigenspectrum to get rgb eigenvectors
def plot_eigenspectrum_rgb(
    model,
    digit,
    eigenvectors=3,
    eigenvalues=20,
    ignore_pos=(),
    ignore_neg=(),
    image_shape=(32, 32),
):
    """Plot both ends of the eigenspectrum of one class of a 3-channel model.

    The figure has two rows and ``1 + eigenvectors`` columns. Row 1 uses the
    entries at the end of the decomposition (indices ``-1, -2, ...``), row 2
    the entries at the start (indices ``0, 1, ...``). Column 1 of each row is a
    line plot of ``eigenvalues + 2`` eigenvalues with the displayed ones
    marked; the remaining columns are heatmaps of the matching eigenvectors.

    Each eigenvector is reshaped to ``(3, height, width)`` and averaged over
    the first axis, so the heatmaps show the channel mean rather than a colour
    image.

    NOTE(review): the reshape treats the flattened input as channel-first —
    confirm that this matches how the dataset flattens RGB images. The
    ``positive`` / ``negative`` naming also suggests that ``model.decompose()``
    returns eigenvalues in ascending order — confirm.

    Args:
        model: Object whose ``decompose()`` returns ``(vals, vecs)``, both
            indexable by ``digit``.
        digit: Index of the output class to visualize.
        eigenvectors: Number of eigenvectors drawn per row.
        eigenvalues: Size of the eigenvalue window in the line plots
            (``eigenvalues + 2`` points are drawn); also the single x tick.
        ignore_pos: Positions (0 = last eigenvalue) among the displayed row-1
            eigenvalues whose y tick is omitted, e.g. to avoid overlapping
            tick labels.
        ignore_neg: Same as ``ignore_pos`` for row 2 (0 = first eigenvalue).
        image_shape: ``(height, width)`` of one channel of an eigenvector.

    Returns:
        The assembled plotly figure (not shown).
    """
    colors = px.colors.qualitative.Plotly
    fig = make_subplots(rows=2, cols=1 + eigenvectors)

    vals, vecs = model.decompose()
    vals, vecs = vals[digit].cpu(), vecs[digit].cpu()

    # Index tensors: `negative` = 0, 1, ... (start of the spectrum),
    # `positive` = -1, -2, ... (end of the spectrum).
    negative = torch.arange(eigenvectors)
    positive = -1 - negative

    # Row 1, column 1: the tail of the spectrum, flipped so the last
    # eigenvalue sits at x=0, with the displayed eigenvalues marked.
    fig.add_trace(
        go.Scatter(y=vals[-eigenvalues - 2 :].flip(0), mode="lines"), row=1, col=1
    )
    fig.add_trace(
        go.Scatter(
            x=negative.flip(0),
            y=vals[positive].flip(0),
            mode="markers",
            marker=dict(color=colors[0]),
        ),
        row=1,
        col=1,
    )

    # Row 2, column 1: the head of the spectrum with its markers.
    fig.add_trace(
        go.Scatter(
            y=vals[: eigenvalues + 2], mode="lines", marker=dict(color=colors[1])
        ),
        row=2,
        col=1,
    )
    fig.add_trace(
        go.Scatter(
            x=negative, y=vals[negative], mode="markers", marker=dict(color=colors[1])
        ),
        row=2,
        col=1,
    )

    # Columns 2..: one heatmap per displayed eigenvector, averaged over the
    # three channels. The rows are flipped because a Heatmap's y axis
    # increases upwards.
    for row, indices in ((1, positive), (2, negative)):
        for col, idx in enumerate(indices, start=2):
            vec_img = vecs[idx].view(3, *image_shape).mean(0)
            fig.add_trace(
                go.Heatmap(
                    z=vec_img.flip(0),
                    colorscale="RdBu",
                    zmid=0,
                    showscale=False,
                ),
                row=row,
                col=col,
            )

    # Hide every axis, then re-enable only what the line plots need.
    fig.update_xaxes(visible=False).update_yaxes(visible=False)
    fig.update_xaxes(
        visible=True,
        tickvals=[eigenvalues],
        ticktext=[f"{eigenvalues}"],
        zeroline=False,
        col=1,
    )
    fig.update_yaxes(zeroline=True, rangemode="tozero", col=1)

    # y ticks: zero plus each displayed eigenvalue that is not ignored.
    for row, indices, ignored in (
        (1, positive, ignore_pos),
        (2, negative, ignore_neg),
    ):
        tickvals = [0] + [
            val.item() for i, val in enumerate(vals[indices]) if i not in ignored
        ]
        ticktext = [f"{val:.2f}" for val in tickvals]
        fig.update_yaxes(
            visible=True, tickvals=tickvals, ticktext=ticktext, col=1, row=row
        )

    fig.update_coloraxes(showscale=False)
    fig.update_layout(
        autosize=False,
        width=170 * (eigenvectors + 1),
        height=300,
        margin=dict(l=0, r=0, b=0, t=0),
        template="plotly_white",
    )
    fig.update_legends(visible=False)

    return fig


In [11]:
# Show the eigenspectrum figure (top and bottom 4 eigenvectors, channel-averaged)
# for every class of the RGB CIFAR-10 model
for i, class_name in enumerate(rgb_cifar10_labels):
    print(f"Eigenspectrum Plot for {class_name}")
    eigenspectrum_fig = plot_eigenspectrum_rgb(
        rgb_cifar10_model, digit=i, eigenvectors=4
    )
    eigenspectrum_fig.show()

Eigenspectrum Plot for airplane


Eigenspectrum Plot for automobile


Eigenspectrum Plot for bird


Eigenspectrum Plot for cat


Eigenspectrum Plot for deer


Eigenspectrum Plot for dog


Eigenspectrum Plot for frog


Eigenspectrum Plot for horse


Eigenspectrum Plot for ship


Eigenspectrum Plot for truck


# Filtered RGB

In [12]:
# Load CIFAR-10 with grayscale conversion disabled, restricted to labels 5 and 8
rgb_filtered_cifar10_train, rgb_filtered_cifar10_test = (
    CIFAR10(train=True, device=device, grayscale=False, labels=[5, 8]),
    CIFAR10(train=False, device=device, grayscale=False, labels=[5, 8]),
)
# Display names of the two kept classes
# NOTE(review): assumes the filtered dataset maps label 5 -> class 0 and
# label 8 -> class 1 — confirm in `image.CIFAR10`
rgb_filtered_cifar10_labels = [
    "dog",
    "ship",
]

# Instantiate the model with a single block (n_layer=1), 3072 inputs and two outputs
rgb_filtered_cifar10_model = Model.from_config(
    wd=1.0,
    lr=0.001,
    batch_size=2048,
    epochs=100,
    n_layer=1,
    d_input=3072,
    d_hidden=512,
    d_output=2,
    seed=420,
).to(device)

# Print the module structure as a sanity check of the configuration
print(rgb_filtered_cifar10_model)

Files already downloaded and verified
Files already downloaded and verified
Model(
  (embed): Linear(
    in_features=3072, out_features=512, bias=False
    (gate): Identity()
  )
  (blocks): ModuleList(
    (0): Bilinear(
      in_features=512, out_features=1024, bias=False
      (gate): Identity()
    )
  )
  (head): Linear(
    in_features=512, out_features=2, bias=False
    (gate): Identity()
  )
  (criterion): CrossEntropyLoss()
)


In [13]:
# Train the filtered RGB model; a RandomGaussianNoise(mean=0, std=0.5, p=1)
# transform is passed to `fit` as its third argument
rgb_filtered_cifar10_metrics = rgb_filtered_cifar10_model.fit(
    rgb_filtered_cifar10_train,
    rgb_filtered_cifar10_test,
    RandomGaussianNoise(mean=0, std=0.5, p=1),
)

train/loss: 0.268, train/acc: 0.892, val/loss: 0.289, val/acc: 0.889: 100%|██████████| 100/100 [00:36<00:00,  2.75it/s]


In [14]:
# Show the eigenspectrum figure (top and bottom 4 eigenvectors, channel-averaged)
# for both classes of the filtered (dog vs. ship) RGB CIFAR-10 model
for i, class_name in enumerate(rgb_filtered_cifar10_labels):
    print(f"Eigenspectrum Plot for {class_name}")
    eigenspectrum_fig = plot_eigenspectrum_rgb(
        rgb_filtered_cifar10_model, digit=i, eigenvectors=4
    )
    eigenspectrum_fig.show()

Eigenspectrum Plot for dog


Eigenspectrum Plot for ship
