In [None]:
import urllib
import warnings
from functools import partial
from typing import Any, Callable, List, Literal, Optional, Tuple, Union
from urllib.request import urlopen

import matplotlib.pyplot as plt
import numpy as np
import plenoptic as po
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from plenoptic.tools.display import clean_up_axes
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from timm.utils import AttentionExtract
from torch import Tensor
from torchvision import transforms
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

timm.layers.set_fused_attn(False)  # Expose all attention internals

# needed for the plotting/animating:
%matplotlib inline
plt.rcParams["animation.html"] = "html5"
# use single-threaded ffmpeg for animation writer
plt.rcParams["animation.writer"] = "ffmpeg"
plt.rcParams["animation.ffmpeg_args"] = ["-threads", "1"]

In [None]:
# DINO ViT-S/16, with classification head (untrained).
# NOTE(review): timm may build this self-supervised checkpoint with num_classes=0 (identity
# head); check `model.num_classes` before reading the output as class logits.
model = timm.create_model("timm/vit_small_patch16_224.dino", pretrained=True).eval()

# Alternative -- CLIP model, fine-tuned on ImageNet:
# model = timm.create_model("timm/vit_base_patch32_clip_224.laion2b_ft_in12k_in1k", pretrained=True)

# Preprocessing pipeline derived from the checkpoint's own data configuration.
data_config = resolve_data_config(model.pretrained_cfg, model=model)
transform = create_transform(**data_config)

In [None]:
# Display the checkpoint's pretrained configuration (source of `label_names` and data settings).
model.pretrained_cfg

In [None]:
# Class names: prefer the labels shipped with the checkpoint config. Fall back to the
# ImageNet-1k lemma list when the config has none (key missing *or* set to None).
labels = model.pretrained_cfg.get("label_names")
if labels is None:
    labels_url = "https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt"
    # Use the response as a context manager so the connection is closed after reading.
    with urlopen(labels_url) as response:
        labels = response.read().decode("utf-8").splitlines()

# Show the resolved preprocessing config (cell output).
resolve_data_config(model.pretrained_cfg, model=model)

In [None]:
# Peek at the first ten class names.
print(labels[:10])

In [None]:
# Download and open the image
url = "https://github.com/EliSchwartz/imagenet-sample-images/blob/master/n02110958_pug.JPEG?raw=true"

# url = "https://github.com/EliSchwartz/imagenet-sample-images/blob/master/n01491361_tiger_shark.JPEG?raw=true"

# `.convert("RGB")` decodes the pixels while the HTTP response is still open (it is closed
# when the `with` exits) and maps grayscale, palette, LA and RGBA images to exactly 3
# channels -- a `shape[0] == 1` check alone would let 2- and 4-channel inputs through.
with urlopen(url) as response:
    original_img = Image.open(response).convert("RGB")

# PILToTensor keeps uint8 values in [0, 255]; always rescale to [0, 1]. A data-dependent
# `img.max() > 1` test would skip the rescale for near-black images.
img = transforms.PILToTensor()(original_img).unsqueeze(0).float().to(0) / 255.0

print(img.shape)
po.imshow(img, as_rgb=True);

In [None]:
# Apply the checkpoint's preprocessing (`transform` from `create_transform`) to the image.
# NOTE(review): despite the name this is ViT, not ResNet, preprocessing; renaming would also
# require updating the cell that displays it.
img_resnet_ready = transform(img)

In [None]:
# Display the image exactly as the model receives it after preprocessing.
po.imshow(img_resnet_ready, as_rgb=True);

In [None]:
# Names of the nodes in the traced model graph (train- and eval-mode variants); these are
# the names `create_feature_extractor` accepts as return nodes.
train_nodes, eval_nodes = get_graph_node_names(model)

In [None]:
# Last ten eval-mode node names (useful for picking the block / attention nodes to extract).
eval_nodes[-10:]

In [None]:
class IntermediateOutputViT(nn.Module):
    """Wrap a ViT so that ``forward`` returns the output of a single transformer block.

    torchvision's ``create_feature_extractor`` is used to read two graph nodes of block
    ``block_index``: the block output (``blocks.{i}``) and the attention-dropout node
    (``blocks.{i}.attn.attn_drop``). The attention node only exists when attention is not
    fused (the notebook calls ``timm.layers.set_fused_attn(False)`` for this).

    Parameters
    ----------
    model
        The ViT to read from (e.g. a timm ``VisionTransformer``).
    block_index
        Index of the transformer block whose output is returned.
    transform
        Optional preprocessing applied to every input before it enters ``model``.
    """

    def __init__(self, model: nn.Module, block_index: int, transform: Optional[Callable] = None):
        super().__init__()
        self.block_index = block_index
        self.attention_output = f"blocks.{block_index}.attn.attn_drop"
        self.feature_representation = f"blocks.{block_index}"
        # Leading non-patch tokens (class token plus any register tokens). timm ViTs expose
        # this as `num_prefix_tokens`; fall back to a single class token otherwise.
        self.num_prefix_tokens = int(getattr(model, "num_prefix_tokens", 1))
        self.extractor = create_feature_extractor(
            model, return_nodes=[self.attention_output, self.feature_representation]
        )
        self.model = model
        self.transform = transform

    def _extractor(self, x: Tensor) -> dict:
        """Apply ``transform`` (if any) and return the extracted nodes keyed by node name."""
        if self.transform is not None:
            x = self.transform(x)
        return self.extractor(x)

    def forward(self, x: Tensor) -> Tensor:
        """Return the output of block ``block_index`` for input ``x``."""
        return self._extractor(x)[self.feature_representation]

    @staticmethod
    def _patch_grid(tokens: Tensor, num_prefix_tokens: int) -> Tensor:
        """Drop the leading prefix tokens and reshape the patch tokens into a square grid.

        The token axis is the first axis of ``tokens``; trailing axes are kept, so ``(N,)``
        becomes ``(side, side)`` and ``(N, D)`` becomes ``(side, side, D)``.

        Raises
        ------
        ValueError
            If the number of remaining patch tokens is not a perfect square.
        """
        patches = tokens[num_prefix_tokens:]
        num_patches = patches.shape[0]
        side = int(round(num_patches ** 0.5))
        if side * side != num_patches:
            raise ValueError(
                f"Expected a square number of patch tokens after dropping {num_prefix_tokens} "
                f"prefix token(s), got {num_patches}."
            )
        return patches.reshape(side, side, *patches.shape[1:])

    @staticmethod
    def _min_max_normalize(values: Tensor) -> Tensor:
        """Linearly rescale ``values`` to [0, 1]; a constant input maps to zeros instead of NaN."""
        low, high = values.min(), values.max()
        span = high - low
        if span <= 0:
            return torch.zeros_like(values)
        return (values - low) / span

    def plot_representation(
        self,
        data: Tensor,
        ax: Optional["plt.Axes"] = None,
        figsize: Tuple[float, float] = (15, 15),
        ylim: Optional[Union[Tuple[float, float], Literal[False]]] = None,
        batch_idx: int = 0,
        title: Optional[str] = None,
    ) -> Tuple["plt.Figure", List["plt.Axes"]]:
        """Plot a block representation: patch-token map (top) and class-token vector (bottom).

        Parameters
        ----------
        data
            A representation shaped like the output of ``forward`` (or a difference of two
            such outputs), indexed as ``data[batch_idx][token, dim]``.
        ax
            Axes to subdivide into the two panels; a new figure is created when None.
        figsize
            Figure size, only used when ``ax`` is None.
        ylim
            Unused. NOTE(review): presumably accepted so plenoptic's plotting helpers can
            pass it through -- confirm.
        batch_idx
            Which batch element to plot.
        title
            Optional prefix for both panel titles.

        Returns
        -------
        The figure and the two axes that were drawn on.
        """
        feature_representation = data[batch_idx]

        # The first token is treated as the class token.
        class_token_representation = feature_representation[0].squeeze().detach().cpu().numpy()

        # Average every token over the embedding axis, then lay the patch tokens out on their
        # spatial grid (prefix tokens are dropped).
        dim_average_representation = feature_representation.mean(1)
        patch_representation_grid = (
            self._patch_grid(dim_average_representation, self.num_prefix_tokens).detach().cpu().numpy()
        )

        # Determine figure layout
        if ax is None:
            fig, axes = plt.subplots(2, 1, figsize=figsize, gridspec_kw={"height_ratios": [1, 1]})
        else:
            ax = clean_up_axes(ax, False, ["top", "right", "bottom", "left"], ["x", "y"])
            gs = ax.get_subplotspec().subgridspec(2, 1, height_ratios=[3, 1])
            fig = ax.figure
            axes = [fig.add_subplot(gs[0]), fig.add_subplot(gs[1])]

        # Plot the patch-token map, averaged over the embedding dimension
        patch_title = (
            f"{title} - patch tokens" if title is not None else "Average Representation across patch tokens"
        )
        po.imshow(
            ax=axes[0],
            image=patch_representation_grid[None, None, ...],
            title=patch_title,
            vrange="auto0",
        )

        # Plot the class token representation
        class_title = "Class Token Representation" if title is None else f"{title} - Class Token Representation"
        axes[1].plot(class_token_representation)
        axes[1].set_xlabel("Dimension")
        axes[1].set_ylabel("Value")
        axes[1].set_title(class_title)

        return fig, axes

    def plot_attention(
        self,
        x: Tensor,
        ax: Optional["plt.Axes"] = None,
        figsize: Tuple[float, float] = (15, 15),
        ylim: Optional[Union[Tuple[float, float], Literal[False]]] = None,
        batch_idx: int = 0,
        title: Optional[str] = None,
        head_fusion: str = "mean",
        upsample_size: Tuple[int, int] = (64, 64),
    ) -> Tuple["plt.Figure", "plt.Axes"]:
        """Plot the first token's (usually CLS) attention over the patch grid for input ``x``.

        Parameters
        ----------
        x
            Input image batch (``transform`` is applied internally).
        ax
            Axes to draw into; a new figure is created when None.
        figsize
            Figure size, only used when ``ax`` is None.
        ylim
            Unused; kept for signature parity with ``plot_representation``.
        batch_idx
            Which batch element to plot.
        title
            Plot title; a default mentioning the block index is used when None.
        head_fusion
            How to combine attention heads: ``"mean"``, ``"max"``, ``"min"`` or
            ``"mean_std"`` (mean divided by standard deviation across heads).
        upsample_size
            Display resolution the patch grid is bilinearly upsampled to.

        Returns
        -------
        The figure and the single axes that were drawn on.

        Raises
        ------
        ValueError
            If ``head_fusion`` is not one of the supported methods.
        """
        # The attention map is only visualized, so skip building an autograd graph (inputs
        # such as a metamer may require grad).
        with torch.no_grad():
            attn_map = self._extractor(x)[self.attention_output]

        attn_map = attn_map[batch_idx].detach()  # Remove batch dimension

        if head_fusion == "mean_std":
            attn_map = attn_map.mean(0) / attn_map.std(0)
        elif head_fusion == "mean":
            attn_map = attn_map.mean(0)
        elif head_fusion == "max":
            attn_map = attn_map.amax(0)
        elif head_fusion == "min":
            attn_map = attn_map.amin(0)
        else:
            raise ValueError(f"Invalid head fusion method: {head_fusion}")

        # Use the first token's attention (in most ViTs the class token) and keep only its
        # attention to the patch tokens, laid out on the spatial grid.
        attn_grid = self._patch_grid(attn_map[0], self.num_prefix_tokens)

        # Upsample the coarse patch grid for a smoother display
        attn_grid = F.interpolate(attn_grid[None, None], size=upsample_size, mode="bilinear", align_corners=False)

        # Normalize attention map to [0, 1] (constant maps become all zeros, not NaN)
        attn_grid = self._min_max_normalize(attn_grid[0, 0]).cpu().numpy()

        # Determine figure layout
        if ax is None:
            fig, axes = plt.subplots(1, 1, figsize=figsize)
        else:
            ax = clean_up_axes(ax, False, ["top", "right", "bottom", "left"], ["x", "y"])
            gs = ax.get_subplotspec().subgridspec(1, 1)
            fig = ax.figure
            axes = fig.add_subplot(gs[0])

        # Plot the attention map
        po.imshow(
            attn_grid[None, None, ...],
            ax=axes,
            title=title if title is not None else f"Self-attention map for CLS token @ block {self.block_index}",
        )

        return fig, axes

In [None]:
# Wrap block 11 of the ViT; `transform` runs inside the wrapper, so it is fed raw [0, 1] images.
test_model = IntermediateOutputViT(model, 11, transform).to(0)

In [None]:
# Switch to inference mode and strip parameter gradients, then let plenoptic check that the
# wrapper is usable as a synthesis model for 224x224 RGB inputs.
test_model.eval()
po.tools.remove_grad(test_model)
po.tools.validate.validate_model(test_model, device=0, image_shape=(1, 3, 224, 224))

In [None]:
# Shape of the block-11 output for the example image.
test_model(img.to(0)).shape

In [None]:
# Plot the block-11 representation of the example image (patch-token map + class token).
block_representation = test_model(img.to(0))
test_model.plot_representation(block_representation, title="Representation at block 11");

In [None]:
# Head-averaged attention of the first (CLS) token for the original image.
test_model.plot_attention(img.to(0), figsize=(5, 5));

In [None]:
def low_pass_gaussian(img, kernel_size=11, sigma=5):
    """Low-pass filter ``img`` by blurring it with a Gaussian kernel.

    Thin wrapper around ``torchvision.transforms.GaussianBlur``; ``kernel_size`` and ``sigma``
    are forwarded unchanged and the blurred image is returned.
    """
    return transforms.GaussianBlur(kernel_size=kernel_size, sigma=sigma)(img)


# Seed torch's global RNG (all devices) so any random initialization inside Metamer, and
# therefore the synthesized metamer, is reproducible across runs.
torch.manual_seed(0)

met = po.synth.Metamer(
    img,
    test_model,
    # Alternative: start from a blurred copy of the target instead of the default initialization.
    # initial_image=low_pass_gaussian(img)
)
optim = torch.optim.AdamW([met.metamer], lr=5e-3)
# Halve the learning rate when the loss has not improved for 50 consecutive scheduler steps.
# `verbose` is not passed: it is deprecated in recent PyTorch (passing it at all emits a
# warning) and non-verbose is already the default.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, "min", patience=50, factor=0.5)

To synthesize the model metamer, we use the `synthesize` method. Setting the `store_progress` argument stores copies of the model metamer during synthesis (here, `store_progress=10` keeps one every 10 iterations), which lets us visualize synthesis progress after the fact:

In [None]:
# Run synthesis: 5000 is the iteration budget; store_progress=10 keeps a metamer snapshot every
# 10 iterations; the optimizer and LR scheduler are the ones configured above.
# NOTE(review): stop_criterion / stop_iters_to_check control early stopping (presumably: stop once
# the loss changes by less than 1e-6 over the last 100 iterations) -- confirm in plenoptic's docs.
met.synthesize(
    5000, store_progress=10, optimizer=optim, scheduler=scheduler, stop_criterion=1e-6, stop_iters_to_check=100
)

The plot on the left shows the model metamer, the middle plot shows the synthesis loss, and the plot on the right shows the model representation error:

In [None]:
# Synthesis status at the final iteration (iteration=-1).
po.synth.metamer.plot_synthesis_status(met, ylim=False, iteration=-1, zoom=1, figsize=(25, 10));

In [None]:
fig, axes = plt.subplots(1, 4, figsize=(16, 4))

po.imshow(transform(img), ax=axes[0], title="Input image")

test_model.plot_attention(img.to(0), ax=axes[1], title="Attention map for original image")

test_model.plot_attention(met.metamer, ax=axes[2], title="Attention map for metamer")

# Attention is only visualized here, so no autograd graph is needed even though the
# metamer tensor requires grad.
with torch.no_grad():
    metamer_attn = test_model._extractor(met.metamer)[test_model.attention_output]
    original_attn = test_model._extractor(img.to(0))[test_model.attention_output]

# Select the single batch element explicitly (`.squeeze()` would also drop the head axis of
# a one-head model), then average over heads.
attn_difference = (metamer_attn - original_attn)[0].mean(0)

# Use the first token's attention (in most ViTs the class token)
attn_map_diff = attn_difference[0]

# Drop the prefix tokens (class token plus any register tokens) and reshape the patch
# tokens to their square spatial grid.
num_prefix_tokens = int(getattr(test_model.model, "num_prefix_tokens", 1))
patch_attn_diff = attn_map_diff[num_prefix_tokens:]
num_patches = int(round(patch_attn_diff.shape[0] ** 0.5))
attn_map_diff = patch_attn_diff.reshape(num_patches, num_patches)

# Upsample to the same 64x64 display size used by plot_attention
attn_map_diff = F.interpolate(attn_map_diff[None, None], size=(64, 64), mode="bilinear", align_corners=False)
attn_map_diff = attn_map_diff[0, 0].cpu().numpy()

# The difference is signed, so use a value range symmetric around zero.
po.imshow(attn_map_diff[None, None, ...], ax=axes[3], title="Difference in attention maps", vrange="auto0")

plt.suptitle(
    f"Attention map comparison between original image and metamer for representations @ block {test_model.block_index}"
);

In [None]:
print("Class label for the original model:")
# NOTE(review): this indexes `labels` with the argmax of the raw model output, which is only a
# class name if the model has a classifier head with one output per label -- confirm.
original_top_class = model(transform(img).to(0)).squeeze().argmax().item()
print(labels[original_top_class])

In [None]:
print("Class label for the metamer:")
# NOTE(review): same caveat as above -- only meaningful with one model output per label.
metamer_top_class = model(transform(met.metamer).to(0)).squeeze().argmax().item()
print(labels[metamer_top_class])

In [None]:
# Get predictions for both images (no gradients needed for plotting)
with torch.no_grad():
    original_logits = model(transform(img).to(0)).squeeze()
    metamer_logits = model(transform(met.metamer).to(0)).squeeze()

# Class names only correspond to outputs when the model emits one logit per label.
if original_logits.numel() != len(labels):
    warnings.warn(
        f"Model outputs {original_logits.numel()} values but {len(labels)} labels were loaded; "
        "the class names below may not correspond to the model's outputs."
    )

# Convert to probabilities
original_probs = F.softmax(original_logits, dim=0)
metamer_probs = F.softmax(metamer_logits, dim=0)

# Union of each image's top-k classes (k capped at the number of outputs so topk cannot fail)
k = min(10, original_probs.numel())
top_original = torch.topk(original_probs, k)
top_metamer = torch.topk(metamer_probs, k)
combined_top_indices = torch.unique(torch.cat([top_original.indices, top_metamer.indices]))

# Gather what the plots need once: plain-int label lookups and CPU/NumPy probabilities
combined_top_labels = [labels[i] for i in combined_top_indices.tolist()]
original_top_probs = original_probs[combined_top_indices].cpu().numpy()
metamer_top_probs = metamer_probs[combined_top_indices].cpu().numpy()

# Create a figure with two plots
fig, axs = plt.subplots(1, 2, figsize=(18, 8))

# Plot 1: Scatter plot of probabilities
axs[0].scatter(original_top_probs, metamer_top_probs, alpha=0.7, s=100)

# Add diagonal line (equal probability under both images)
max_prob = max(original_probs.max().item(), metamer_probs.max().item())
axs[0].plot([0, max_prob], [0, max_prob], "k--", alpha=0.5)

# Label key points
for label, original_p, metamer_p in zip(combined_top_labels, original_top_probs, metamer_top_probs):
    axs[0].annotate(label, (float(original_p), float(metamer_p)), fontsize=9)

axs[0].set_xlabel("Original Image Probability")
axs[0].set_ylabel("Metamer Probability")
axs[0].set_title("Class Probabilities Comparison")
axs[0].grid(True, alpha=0.3)

# Plot 2: Bar chart of top predictions
x = np.arange(len(combined_top_labels))
width = 0.35

axs[1].bar(x - width / 2, original_top_probs, width, label="Original")
axs[1].bar(x + width / 2, metamer_top_probs, width, label="Metamer")

axs[1].set_xticks(x)
axs[1].set_xticklabels(combined_top_labels, rotation=45, ha="right")
axs[1].set_ylabel("Probability")
axs[1].set_title("Top Class Probabilities")
axs[1].legend()

plt.tight_layout()
plt.show()