# Vision models

## Orientation: ResNet-18


1. Let's start by training a simple ResNet-18 model and take lots of checkpoints.
2. Then do feature visualization on the end results (for a random sample of neurons). 
3. Look at how the activation of the target neuron reacts to those feature visualizations over the course of training.

In [5]:
import os
from dataclasses import dataclass, field
from typing import Optional, Container, Tuple, List, Dict, TypedDict
from dataclasses import asdict
import math
from typing import Callable
import functools
import random
import logging

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
import torchvision.utils as vutils
from PIL import Image
from tqdm.notebook import tqdm
import numpy as np
import pandas as pd
import wandb 
from torch.optim.lr_scheduler import LambdaLR 
from dotenv import load_dotenv
import matplotlib.pyplot as plt

from devinterp.config import Config, OptimizerConfig, SchedulerConfig
from devinterp.checkpoints import CheckpointManager
from devinterp.logging import Logger
from devinterp.learner import Learner
from devinterp.misc.io import gen_images, show_images
from devinterp.viz.activations import FeatureVisualizer, ActivationProbe
from devinterp.misc.io import show_images



load_dotenv("../.env")
# wandb.finish()
logging.basicConfig(level=logging.INFO)

In [2]:
torch.manual_seed(0)
model: nn.Module = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False)

Using cache found in /Users/Jesse/.cache/torch/hub/pytorch_vision_v0.10.0


In [3]:
# Data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_transforms = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

train_set = datasets.CIFAR10(root='../data', train=True, download=True, transform=train_transforms)
test_set = datasets.CIFAR10(root='../data', train=False, download=True, transform=transform)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
from devinterp.config import Config
import yaml

config = Config(
    num_training_samples=len(train_set), 
    num_steps=64_000, 
    project="resnet18", 
    entity="devinterp", 
    logging_steps=(100, 100), 
    checkpoint_steps=(25, 25),
    optimizer_config=OptimizerConfig(
        optimizer_type="SGD",
        lr=0.1,
        momentum=0.9,
    ),
    scheduler_config=SchedulerConfig(
        scheduler_type="MultiStepLR",
        milestones=[16_000, 32_000, 48_000], 
        gamma=0.5
    ),
)

print(yaml.dump(config.model_dump(exclude=("logging_steps", "checkpoint_steps"))))

INFO:devinterp.config:Logging to wandb enabled (project: resnet18, entity: devinterp)
INFO:devinterp.config:batch_size: 128
checkpoint_steps: !!set
  0: null
  1: null
  2: null
  3: null
  6: null
  10: null
  15: null
  25: null
  40: null
  63: null
  100: null
  159: null
  252: null
  401: null
  636: null
  1008: null
  1600: null
  2537: null
  2666: null
  4023: null
  5333: null
  6381: null
  8000: null
  10119: null
  10666: null
  13333: null
  16000: null
  16047: null
  18666: null
  21333: null
  24000: null
  25448: null
  26666: null
  29333: null
  32000: null
  34666: null
  37333: null
  40000: null
  40357: null
  42666: null
  45333: null
  48000: null
  50666: null
  53333: null
  56000: null
  58666: null
  61333: null
  64000: null
device: cpu
entity: devinterp
logging_steps: !!set
  0: null
  1: null
  2: null
  3: null
  4: null
  5: null
  6: null
  7: null
  8: null
  9: null
  10: null
  11: null
  13: null
  14: null
  16: null
  18: null
  20: null
  22:

batch_size: 128
device: cpu
entity: devinterp
num_epochs: 165
num_steps: 64000
num_training_samples: 50000
optimizer_config:
  lr: 0.1
  momentum: 0.9
  optimizer_type: SGD
  weight_decay: 0.0001
project: resnet18
scheduler_config:
  gamma: 0.5
  last_epoch: -1
  milestones:
  - 16000
  - 32000
  - 48000
  scheduler_type: MultiStepLR



In [16]:
def loss_metric(model, data, target, output):
    return F.cross_entropy(output, target, reduction="sum")

def accuracy_metric(model, data, target, output):
    pred = output.argmax(dim=1, keepdim=True)
    return pred.eq(target.view_as(pred)).sum()

def eval_learner(learner: Learner):
    return dataloaders_reduce(learner.model, {"Train": learner.train_loaders, "Test": learner.test_loaders}, {"Loss": loss_metric, "Accuracy": accuracy_metric}, device=learner.config.device)

learner = Learner(model, train_set, test_set, config, metrics=[eval_learner])
learner.train()



0,1
Batch/Loss,▁

0,1
Batch/Loss,7.21299


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016671020416682343, max=1.0…

Epoch 0 Batch 0/64000 Loss: ?.??????:   0%|          | 0/64000 [00:00<?, ?it/s]

# Feature visualization

We have a trained `model` (and a bunch of checkpoints). First, let's do some classic feature visualization on the final network. We'll select a few random neurons from ac

In [65]:
import glob
import logging
import os
import warnings
from pathlib import Path
from typing import Dict, List, Optional, Tuple, TypedDict, Union, Generic, TypeVar, Literal, Set

import boto3
import torch
from botocore.exceptions import ClientError

logger = logging.getLogger(__name__)

IDType = TypeVar("IDType")

class StorageProvider(Generic[IDType]):
    """
    Wrapper for either local or cloud (S3) storage (or both).

    :param bucket_name: The name of the S3 bucket to store checkpoints (Optional)
    :param local_root: If provided, then the base directory in which to save files locally. If omitted, files will not be saved locally. (Optional)
    :param save_locally: If True, saves checkpoints locally without deleting them (Optional)
    """
    file_ids: List[IDType]

    def __init__(self, bucket_name: Optional[str] = None, local_root: Optional[str] = None, parent_dir: str = "data",  device=torch.device("cpu")):
        self.bucket_name = bucket_name
        self.is_local_enabled = local_root is not None
        self.local_root = Path(local_root or "tmp")
        self.parent_dir = parent_dir
        self.device = device

        self.file_ids = [] # Any non-int hashable type.
        self.client = None

        # Cloud 

        if bucket_name and (os.getenv("AWS_SECRET_ACCESS_KEY") and os.getenv("AWS_ACCESS_KEY_ID")):
            self.client = boto3.client("s3")
            self.file_ids = self.get_file_ids()
        else:
            warnings.warn("AWS_SECRET_ACCESS_KEY and AWS_ACCESS_KEY_ID must be set to use S3 bucket.")

        # Local 
        
        local_path = os.path.join(self.local_root, parent_dir)
        if self.is_local_enabled and not os.path.exists(local_path):
            os.makedirs(local_path)
        
        if not self.bucket_name and not self.local_root:
            warnings.warn("Neither S3 bucket name provided nor local_root is defined. Files will not be persisted.")

    @property
    def is_s3_enabled(self):
        return self.client is not None
    
    def id_to_name(self, file_id: Union[IDType, Literal["*"]]) -> str:
        """Should contain no `/` and should handle the wildcard."""
        raise NotImplementedError
    
    def id_to_key(self, file_id: Union[IDType, Literal["*"]]) -> str:
        return f"{self.parent_dir}/{self.id_to_name(file_id)}.pt"
    
    def name_to_id(self, name: str) -> IDType:
        raise NotImplementedError

    def get_file_ids(self) -> List[IDType]:
        """
        Returns a list of tuples (epoch, batch_idx) of all checkpoints in the bucket or local directory.
        """
        file_ids: Set[IDType] = set()

        if self.is_local_enabled:
            files = glob.glob(f"{self.local_root}/{self.id_to_key('*')}")
            file_ids |= {self.name_to_id(os.path.basename(f)) for f in files}

        if self.is_s3_enabled:
            response = self.client.list_objects_v2(Bucket=self.bucket_name)
            if "Contents" in response:
                file_ids |= {self.name_to_id(item["Key"]) for item in response["Contents"] if item["Key"].startswith(self.parent_dir)}
            
        return sorted(list(file_ids))

    def upload_file(self, file_path: str, key: str):
        self.client.upload_file(file_path, self.bucket_name, key)

    def save_file(self, file_id: str, file):
        file_path = self.id_to_key(file_id)
        rel_file_path = self.local_root / file_path
        torch.save(file, rel_file_path)

        if self.client:
            self.upload_file(rel_file_path, file_path)

        if not self.is_local_enabled:
            os.remove(rel_file_path)

    def load_file(self, file_id):
        file_path = self.id_to_key(file_id)
        rel_file_path = self.local_root / file_path

        print(file_id, file_path, rel_file_path)

        if (self.is_local_enabled and os.path.exists(rel_file_path)):
            logger.info(f"Loading {file_path} from local save...")
        elif self.client:
            logger.info(f"Downloading {file_path} from bucket `{self.bucket_name}`...")
            self.client.download_file(self.bucket_name, file_path, rel_file_path)
        else:
            raise OSError(f"File with id `{file_id}` not found either locally or in bucket.")

        checkpoint = torch.load(rel_file_path, map_location=self.device)

        if not self.is_local_enabled and self.bucket_name and self.client:
            os.remove(rel_file_path)

        return checkpoint

    def __iter__(self):
        for file_id in self.file_ids:
            yield self.load_file(file_id)

    def __len__(self):
        return len(self.file_ids)

    def __getitem__(self, idx):
        if isinstance(idx, int):
            return self.load_file(self.file_ids[idx])
        
        elif idx not in self.file_ids:
            warnings.warn(f"File with id `{idx}` not found in {self.bucket_name}.")
            return self.load_file(idx)

        raise TypeError(f"Invalid argument `{idx}` of type `{type(idx)}`")

    def __contains__(self, file_id):
        return file_id in self.file_ids

    def __repr__(self):
        return f"StorageProvider({self.bucket_name}, {self.local_root})"

EpochAndBatch = Tuple[int, int]

class CheckpointManager(StorageProvider[EpochAndBatch]):
    def __init__(self, project_dir: str, bucket_name: Optional[str] = None, local_root: Optional[str] = None,  device=torch.device("cpu")):
        super().__init__(bucket_name, local_root, f"checkpoints/{project_dir}", device=device)

    @staticmethod
    def id_to_name(file_id: Union[EpochAndBatch, Literal["*"]]) -> str:
        if file_id == "*":
            return "*"
        
        epoch, batch = file_id
        return f"checkpoint_epoch_{epoch}_batch_{batch}"

    @staticmethod
    def name_to_id(name: str) -> EpochAndBatch:
        parts = name.split("_")
        epoch = int(parts[-3])
        batch_idx = int(parts[-1].split(".")[0])
        return epoch, batch_idx

    def __repr__(self):
        return f"CheckpointManager({self.parent_dir}, {self.bucket_name})"


NeuronSeedBatch = Tuple[int, int, int]

class VisualizationManager(StorageProvider[NeuronSeedBatch]):
    def __init__(self, project_dir: str, bucket_name: Optional[str] = None, local_root: Optional[str] = None,  device=torch.device("cpu")):
        super().__init__(bucket_name, local_root, f"visualizations/{project_dir}", device=device)

    @staticmethod
    def id_to_name(file_id: NeuronSeedBatch):
        if file_id == "*":
            return "*"
        
        neuron, seed, batch = file_id
        return f"visualization_neuron_{neuron}_seed_{seed}_batch_{batch}"

    @staticmethod
    def name_to_id(name: str) -> NeuronSeedBatch:
        parts = name.split("_")
        neuron = int(parts[-5])
        seed = int(parts[-3])
        batch_idx = int(parts[-1].split(".")[0])
        return neuron, seed, batch_idx

    def __repr__(self):
        return f"VisualizationManager({self.parent_dir}, {self.bucket_name})"

In [66]:
checkpoints = CheckpointManager('ResNet18/CIFAR10', 'devinterp', local_root="..")
model = torchvision.models.resnet18(pretrained=False)
model.load_state_dict(checkpoints[-1]["model"])

INFO:__main__:Loading checkpoints/ResNet18/CIFAR10/checkpoint_epoch_164_batch_64000.pt from local save...


(164, 64000) checkpoints/ResNet18/CIFAR10/checkpoint_epoch_164_batch_64000.pt ../checkpoints/ResNet18/CIFAR10/checkpoint_epoch_164_batch_64000.pt


<All keys matched successfully>

In [46]:
import torch.multiprocessing as mp

def worker(worker_id, viz, device, start, end, queue):
    results = []

    for i, probe in enumerate(viz.activations):
        images = viz.render(
            probe,
            thresholds = thresholds,
            verbose = verbose,
            seed=init_seed + i,
            device=device
        )

        if verbose:
            show_images(*images, **kwargs)

        results.append((images, probe.activation))

    queue.put(results)

class FeatureVisualizer:
    def __init__(self, model: torch.nn.Module, locations: Optional[list[str]]=None):
        self.model = model
        self.locations = locations or self.gen_locations(model)  # Defaults to all neurons in the model
        self.activations = [ActivationProbe(model, location) for location in self.locations]

    @staticmethod
    def gen_locations(model: torch.nn.Module, layer_type: Optional[Union[type, Tuple[type, ...]]] = None) -> list[str]:
        """Generate neurons of a particular kind of layer from a PyTorch model."""
        channel_locations = []

        def recursive_search(module, prefix):
            for name, submodule in module.named_children():
                path = prefix + '.' + name if prefix else name

                if not layer_type or isinstance(submodule, layer_type):

                    # TODO: Get rid of the "weight"
                    if isinstance(submodule, torch.nn.Linear):
                        for feature in range(submodule.out_features):
                            location = f"{path}.weight.{feature}"
                            channel_locations.append(location)

                    elif isinstance(submodule, torch.nn.Conv2d):
                        for channel in range(submodule.out_channels):
                            location = f"{path}.weight.{channel}"
                            channel_locations.append(location)
                    else:
                        warnings.warn(f"Unknown layer type: {type(submodule)}. Skipping")


                recursive_search(submodule, path)

        recursive_search(model, '')

        return channel_locations

    def render(self, probe: ActivationProbe, transform: Optional[Transform] = None, thresholds: list[int]=[512], verbose: bool = True, seed: int = 0, device: torch.device = DEVICE) -> list[torch.Tensor]:
        """Renders an image that maximizes the activation of the specified neuron.
        
        Args:
            transform (transforms.Compose, optional): Image transform to apply during optimization.
            thresholds (list[int], optional): List of iterations at which to save the optimized image.
            verbose (bool, optional): Whether to print progress information.
            seed (int, optional): Random seed for initialization of the input image.
            device (str, optional): Device on which to perform the computation.
            
        Returns:
            tuple[list[torch.Tensor], float]: A tuple containing the final images and the activation value.
        """

        # Assuming 'model' is your pre-trained ResNet model and 'location' is the string specifying the neuron's location
        self.model.to(device)
        self.model.eval()

        with probe.watch():
            # Create a random image (1x3x224x224) to start optimization, with same size as typical ResNet input
            torch.manual_seed(seed)
            input_image = torch.rand((1, 3, 32, 32), requires_grad=True, device=device)

            # Optimizer
            optimizer = torch.optim.Adam([input_image], lr=0.01, weight_decay=1e-3)

            final_images = []

            # Optimization loop
            pbar = range(max(thresholds) + 1)

            if verbose:
                pbar = tqdm(pbar, desc=f"Visualizing {probe.location} (activation: ???)")

            for iteration in pbar:
                optimizer.zero_grad()
                self.model(input_image)  # Forward pass through the model to trigger the hook
                activation = probe.activation
                loss = -activation  # Maximizing activation
                loss.backward()
                optimizer.step()

                if transform:
                    input_image.data = transform(input_image.data.detach().clone())

                if verbose:
                    pbar.set_description(f"Visualizing {probe.location} (activation: {activation.item():.2f})")

                if iteration in thresholds:
                    image = input_image.detach().clone()
                    image = torch.reshape(image, (1, 3, 32, 32))            
                    final_images.append(image)
        
        return final_images

    # def render(self, probe: ActivationProbe, num_visualizations: int, transform: Optional[Transform] = None, thresholds: list[int]=[512], verbose: bool = True, seed: int = 0, device: torch.device = DEVICE, diversity_weight: float = 0.1) -> list[list[torch.Tensor]]:
    #     """Renders multiple images that maximize the activation of the specified neuron, with a penalty to increase diversity.
        
    #     Args:
    #         num_visualizations (int): Number of visualizations to generate.
    #         ... other args same as render ...
    #         diversity_weight (float, optional): Weight of the diversity penalty.
        
    #     Returns:
    #         list[list[torch.Tensor]]: A list containing lists of final images for each visualization.
    #     """

    #     self.model.to(device)
    #     self.model.eval()
        
    #     with probe.watch():

    #         # Create random images (num_visualizations x 3 x 32 x 32) to start optimization
    #         torch.manual_seed(seed)
    #         input_images = torch.rand((num_visualizations, 3, 32, 32), requires_grad=True, device=device)

    #         # Optimizer
    #         optimizer = torch.optim.Adam([input_images], lr=0.01, weight_decay=1e-3)

    #         all_final_images = [[] for _ in range(num_visualizations)]

    #         pbar = range(max(thresholds) + 1)
    #         if verbose:
    #             pbar = tqdm(pbar, desc=f"Visualizing {probe.location} (activation: ???)")

    #         for iteration in pbar:
    #             optimizer.zero_grad()
    #             total_loss = torch.tensor(0.0, device=device)

    #             for i in range(num_visualizations):
    #                 self.model(input_images[i].unsqueeze(0))  # Forward pass through the model to trigger the hook
    #                 total_loss += -probe.activation  # Maximizing activation

    #                 # Calculate diversity penalty for the current image
    #                 for j in range(num_visualizations):
    #                     if i != j:
    #                         diversity_loss = F.cosine_similarity(input_images[i].view(1, -1), input_images[j].view(1, -1))
    #                         total_loss += diversity_weight * diversity_loss

    #             total_loss.backward()
    #             optimizer.step()

    #             if transform:
    #                 input_images.data = transform(input_images.data.detach().clone())

    #             if verbose:
    #                 pbar.set_description(f"Visualizing {probe.location} (activation: {-total_loss.item():.2f})")

    #             if iteration in thresholds:
    #                 for i in range(num_visualizations):
    #                     image = input_images[i].detach().clone()
    #                     image = torch.reshape(image, (3, 32, 32))
    #                     all_final_images[i].append(image)

    #     return all_final_images

    def _render_all(self, thresholds: list[int]=[512], verbose: bool = True, init_seed: int = 0, device: str = "cuda", num_workers: int = 1, **kwargs) -> list[tuple[list[torch.Tensor], float]]:
        results = []

        for i, probe in enumerate(self.activations):
            images = self.render(
                probe,
                thresholds = thresholds,
                verbose = verbose,
                seed=init_seed + i,
                device=device
            )

            if verbose: 
                show_images(*images, **kwargs)

            results.append((images, probe.activation))

        return results

    def render_all(self, thresholds: list[int] = [512], verbose: bool = True, init_seed: int = 0, device: str = "cuda", num_workers: int = 1, **kwargs) -> list[tuple[list[torch.Tensor], float]]:
        if num_workers == 1:
            return self._render_all(thresholds, verbose, init_seed, device, **kwargs)
        
        mp.set_start_method('spawn', force=True)    
        devices = [torch.device(f"cpu:{i}") for i in range(num_workers)]

        queue = mp.Queue()
        processes = []

        # Split the work between the workers
        split_size = len(self) // num_workers
        for worker_id in range(num_workers):
            start_idx = worker_id * split_size
            end_idx = start_idx + split_size if worker_id != num_workers - 1 else len(self)
            p = mp.Process(target=worker, args=(worker_id, viz, devices[worker_id], start_idx, end_idx, queue))
            p.start()
            processes.append(p)

        results = []
        for _ in range(num_workers):
            results += queue.get()

        for p in processes:
            p.join()

        return results
    
    def __len__(self) -> int:
        return len(self.activations)

    def __getitem__(self, idx: Union[int, str]) -> ActivationProbe:
        if isinstance(idx, int):
            return self.activations[idx]
        elif isinstance(idx, str):
            return next(filter(lambda probe: probe.location == idx, self.activations))
        else:
            raise TypeError(f"Invalid type for index: {type(idx)}")

In [47]:
viz = FeatureVisualizer(model, [f'layer1.0.conv1.weight.{i}' for i in range(10)])
image = viz.render_all(thresholds=[0, 64, 128, 192, 255, 511], device="cpu", num_workers=4)

Traceback (most recent call last):
  File "<string>", line 1, in <module>
Traceback (most recent call last):
  File "/usr/local/Cellar/python@3.11/3.11.4_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/spawn.py", line 120, in spawn_main
  File "<string>", line 1, in <module>
  File "/usr/local/Cellar/python@3.11/3.11.4_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/spawn.py", line 120, in spawn_main
        exitcode = _main(fd, parent_sentinel)exitcode = _main(fd, parent_sentinel)

                            ^ ^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
^^
  File "/usr/local/Cellar/python@3.11/3.11.4_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/spawn.py", line 130, in _main
  File "/usr/local/Cellar/python@3.11/3.11.4_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/spawn.py", line 130, in _main
    self = reduction.pickle.load(from_parent)
    self = reduction.pickle.

KeyboardInterrupt: 

### Maximally active neurons

Let's go through all neurons in the model and rank them by their activation. We will then plot the top 10 most active neurons.

In [36]:
locations = FeatureVisualizer.gen_locations(viz.model, (nn.Linear, nn.Conv2d))
len(locations)

5800

In [27]:
neurons_results =  []

for i in range(0, len(conv_neurons), 10):
    section = conv_neurons[i:i+10]
    print(section)
    _results = render_multiple(model, *section, thresholds=[256], device="cuda:0", verbose=False)
    show_images(*[images[-1] for (images, _) in _results], dpi=50)
    neurons_results.extend(_results)

    if i % 100 == 0:
        print(f"Saving results at {i}")
        torch.save(neurons_results, "../visualizations/restnet-cifar10.pt")
        
        # Print the 100 most activated neurons
        print([(name, activation) for (name, (_, activation)) in sorted(zip(conv_neurons, neurons_results), key=lambda x: x[1][1], reverse=True)[:100]])

['layer1.0.conv1.weight.0', 'layer1.0.conv1.weight.1', 'layer1.0.conv1.weight.2', 'layer1.0.conv1.weight.3', 'layer1.0.conv1.weight.4', 'layer1.0.conv1.weight.5', 'layer1.0.conv1.weight.6', 'layer1.0.conv1.weight.7', 'layer1.0.conv1.weight.8', 'layer1.0.conv1.weight.9']


NameError: name 'render_multiple' is not defined

### Developmental analysis of a sample neuron

In [None]:
sample_neuron = "layer1.1.conv1.weights.7"
viz, activation = render(model, sample_neuron, seed=0)[-1]
print(activation)
show_images(viz)

In [None]:
pbar = tqdm(checkpoints, desc="Looping checkpoints (activation: ???)")
activations = []

for state_dict in pbar:
    model.load_state_dict(state_dict)
    extractor = ActivationExtractor(model, sample_neuron)
    handle = extractor.register_hook()

    model.eval()
    with torch.no_grad():
        model(viz) 
        activations.append(extractor.activation)
    
    pbar.set_description(f"Looping checkpoints (activation: {extractor.activation.item():.2f})")

In [None]:
plt.plot([b for (_, b) in checkpoints.checkpoints][-5:], activations[-5:])
plt.xlabel("Training step")
plt.ylabel("Activation")

In [None]:
# Let's do feature visualization at the very start, at 90  steps (where it reaches a minimum) at 5k steps where it's close to 0, at 8600, at 9000, and at the last step. 

# First let's get the closest checkpoints to these steps

ideal_checkpoint_steps = [90, 5000, 8600, 9000, 9999]

def get_closest_checkpoint(checkpoints: list[tuple[int, int]], step: int) -> int:
    return min([chkpt for chkpt in checkpoints], key=lambda x: abs(x[1] - step))

checkpoint_steps = [get_closest_checkpoint(checkpoints.checkpoints, step) for step in ideal_checkpoint_steps]
checkpoint_steps

In [None]:
for (epoch, batch_idx) in tqdm(checkpoint_steps, desc="Going through checkpoints"):
    model.load_state_dict(checkpoints[(epoch, batch_idx)])
    vizs, activation = render(model, sample_neuron, seed=0, thresholds=[0, 64, 128, 256, 512], verbose=True)
    show_images(*vizs)

### Let's do a whole set of neurons

In [None]:
viz_results = torch.load("../visualizations/restnet-cifar10.pt", map_location=torch.device('cpu'))
viz_results = sorted([(name, a, img) for (img, a), name in zip(viz_results, conv_neurons)], key=lambda x: x[1])

In [None]:
import numpy as np
from matplotlib import pyplot as plt

eps = 1e-4
large_eps = 100

activations = [a for _, a, _ in viz_results]

print("# Negative activations: ", len([a for a in activations if a < 0]))
print("# Zero activations: ", len([a for a in activations if a == 0]))
print("# Insignificant positive activations: ", len([a for a in activations if 0 < a <= eps]))
print("# Moderate positive activations: ", len([a for a in activations if eps < a <= large_eps]))
print("# Large positive activations: ", len([a for a in activations if large_eps < a]))

activations = [a for a in activations if a > eps]

log_bins = np.logspace(np.log10(min(activations)),
                       np.log10(max(activations)), num=10)

# Plotting the histogram
plt.hist(activations, bins=log_bins)
plt.xscale('log') # Optional, if you want the x-axis to be logarithmic


In [None]:
# Choose 5 from each category by random
np.random.seed(2)

sample_neurons = [
    *np.random.choice([n for n, a, _ in viz_results if a < -eps], size=5, replace=False),
    *np.random.choice([n for n, a, _ in viz_results if -eps <= a <= eps], size=5, replace=False),
    *np.random.choice([n for n, a, _ in viz_results if eps < a <= large_eps], size=5, replace=False),
    *np.random.choice([n for n, a, _ in viz_results if large_eps < a], size=5, replace=False),
]
print(sample_neurons)
images = [imgs[-1] for n, _, imgs in viz_results if n in sample_neurons]
show_images(
    *images,
    nrow=5
)

In [None]:
def evolve_multiple(model: nn.Module, checkpoints: CheckpointManager, *locations: str, opt_steps: int = 512, **kwargs):
    model.load_state_dict(checkpoints[-1])
    model.eval()

    final_vizs: dict[str, torch.Tensor] = {}
    vizs: dict[str, list[torch.Tensor]] = {}
    activations: dict[str, list[float]] = {}
   
    # Create the visualizations for the last checkpoint
    for location, _location_vizs in zip(locations, tqdm(render_multiple(model, *locations, thresholds=[opt_steps], **kwargs), desc="Creating initial visualizations")):
        final_vizs[location] = _location_vizs[0][0]
        vizs[location] = []
        activations[location] = []
 
    for i, state_dict in enumerate(tqdm(checkpoints, desc="Visiting checkpoints")):
        batch_idx = checkpoints.checkpoints[i][1]

        # Render the visualization for the next checkpoint
        model.load_state_dict(state_dict)

        for location in locations:
            viz = final_vizs[location]

            model.load_state_dict(state_dict)
            extractor = ActivationExtractor(model, location)
            handle = extractor.register_hook()

            with torch.no_grad():
                model(viz) 
                activations[location].append(extractor.activation.item())

            handle.remove()

        wandb.log({f"Activations/{location}": activations[location][-1] for location in locations}, step=batch_idx, commit=False)
            
        # Visualize this checkpoint
        if i % 20 or i == len(checkpoints) - 1:
            for location, _location_vizs in zip(locations, tqdm(render_multiple(model, *locations, thresholds=[opt_steps], **kwargs), desc=f"Creating visualizations for batch {batch_idx}")):
                viz = _location_vizs[0][0]
                vizs[location].append(viz)
                image_np = gen_images(viz)
                image = wandb.Image(image_np, caption=f"Optimized {location} at batch {batch_idx}")

                wandb.log({f"Visualizations/{location}": image}, step=batch_idx)

    return vizs, activations


In [None]:
# wandb.finish()
# run_id = input("Run ID: ")
wandb.init(project=config.project, entity=config.entity)
results = evolve_multiple(model, checkpoints, *sample_neurons, device="cpu", verbose=False)

In [None]:
['layer3.0.conv2.weight.206', 'layer3.0.conv1.weight.149', 'layer3.0.conv1.weight.118', 'layer3.0.conv2.weight.63', 'layer2.0.conv1.weight.102', 'layer2.0.conv2.weight.110', 'layer2.0.conv1.weight.15', 'layer3.1.conv1.weight.20', 'layer2.0.conv2.weight.19', 'layer3.0.conv1.weight.205', 'layer1.0.conv2.weight.54', 'layer2.0.conv2.weight.12', 'layer2.0.downsample.0.weight.99', 'layer1.0.conv2.weight.0', 'layer1.0.conv1.weight.47', 'layer1.0.conv2.weight.41', 'layer1.0.conv2.weight.51', 'layer3.0.downsample.0.weight.125', 'layer2.0.conv1.weight.108', 'layer1.1.conv2.weight.32']
