# Vision models

## Orientation: ResNet-18


1. Let's start by training a simple ResNet-18 model and take lots of checkpoints.
2. Then do feature visualization on the end results (for a random sample of neurons). 
3. Look at how the activation of the target neuron reacts to those feature visualizations over the course of training.

In [7]:
import functools
import math
import os
import random
import warnings
from dataclasses import asdict, dataclass, field
from typing import Callable, Container, Dict, List, Optional, Tuple, TypedDict

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.utils as vutils
import wandb
from dotenv import load_dotenv
from PIL import Image
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

# Kept ahead of the devinterp imports, as in the original ordering.
load_dotenv("../.env")

from devinterp.checkpoints import CheckpointManager
from devinterp.config import Config, OptimizerConfig, SchedulerConfig
from devinterp.data import CustomDataLoader
from devinterp.logging import Logger

# Finish any wandb run that is still open in this kernel
# (presumably one left over from an earlier execution of the notebook).
wandb.finish()

In [8]:
# Seed torch's global RNG before the model is built
# (presumably so the random weight initialisation is reproducible).
torch.manual_seed(0)
# Untrained ResNet-18 (pretrained=False) from the torchvision v0.10.0 hub entry.
model: nn.Module = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False)

Using cache found in /home/paperspace/.cache/torch/hub/pytorch_vision_v0.10.0


In [10]:
# Data
# Steps shared by both pipelines: convert to a tensor, then normalise each
# channel with mean 0.5 and std 0.5.
_normalize_steps = [
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

# Evaluation pipeline: no augmentation.
transform = transforms.Compose(list(_normalize_steps))

# Training pipeline: padded random crop + random horizontal flip, then the
# same conversion/normalisation as above.
train_transforms = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    *_normalize_steps,
])

# CIFAR-10 splits, downloaded into ../data when missing.
_cifar_kwargs = dict(root='../data', download=True)
train_set = datasets.CIFAR10(train=True, transform=train_transforms, **_cifar_kwargs)
test_set = datasets.CIFAR10(train=False, transform=transform, **_cifar_kwargs)

Files already downloaded and verified
Files already downloaded and verified


In [11]:
# Config is imported again here (it is already imported in the first cell).
from devinterp.config import Config
import yaml

# Run configuration for the ResNet-18 / CIFAR-10 training run.
# NOTE(review): the exact meaning of the `logging_steps` / `checkpoint_steps`
# tuples is defined by devinterp's Config and is not visible here - confirm.
# NOTE(review): Learner.__init__ builds its own LambdaLR and leaves the
# `scheduler_config.factory(...)` call commented out, so the MultiStepLR
# settings below do not appear to drive training.
config = Config(
    num_training_samples=len(train_set), 
    num_steps=64_000, 
    project="resnet18", 
    entity="devinterp", 
    logging_steps=(100, 100), 
    checkpoint_steps=(25, 25),
    optimizer_config=OptimizerConfig(
        optimizer_type="SGD",
        lr=0.1,
        momentum=0.9,
    ),
    scheduler_config=SchedulerConfig(
        scheduler_type="MultiStepLR",
        milestones=[16_000, 32_000, 48_000], 
        gamma=0.5
    ),
)

# Print the resolved config as YAML, leaving out the two step schedules.
print(yaml.dump(config.model_dump(exclude=("logging_steps", "checkpoint_steps"))))

batch_size: 128
device: cuda
entity: devinterp
num_epochs: 165
num_steps: 64000
num_training_samples: 50000
optimizer_config:
  lr: 0.1
  momentum: 0.9
  optimizer_type: SGD
  weight_decay: 0.0001
project: resnet18
scheduler_config:
  gamma: 0.5
  last_epoch: -1
  milestones:
  - 16000
  - 32000
  - 48000
  scheduler_type: MultiStepLR





In [12]:
# Shape of a Learner checkpoint: the state dicts of the model and the
# optimizer, plus the scheduler's state dict (None when there is no scheduler).
LearnerStateDict = TypedDict(
    "LearnerStateDict",
    {
        "model": Dict,
        "optimizer": Dict,
        "scheduler": Optional[Dict],
    },
)

class Learner:
    """Training harness bundling a model with its data, optimizer, LR schedule,
    metric logging and checkpointing.

    Args:
        model: Network to train.
        train_set: Training dataset; wrapped in a shuffling ``CustomDataLoader``.
        test_set: Held-out dataset; wrapped in a non-shuffling ``DataLoader``.
        config: Run configuration (batch size, device, step counts, optimizer
            settings, wandb project/entity, logging and checkpoint steps).
        metrics: Optional callables mapping the learner to a dict of values;
            :meth:`measure` merges their outputs.
    """

    def __init__(self, model: torch.nn.Module, train_set: torch.utils.data.Dataset, test_set: torch.utils.data.Dataset, config: Config, metrics: Optional[List[Callable[['Learner'], Dict]]]=None):
        self.config = config
        self.model = model
        self.train_set = train_set
        self.test_set = test_set
        self.optimizer = config.optimizer_config.factory(model.parameters())

        def lr_lambda(step: int):
            # Multiplier applied to the optimizer's base learning rate:
            # a short low-rate phase, the full rate, then two 10x drops.
            if step < 400:
                return 0.1
            elif step < 32_000:
                return 1.
            elif step < 48_000:
                return 0.1
            else:
                return 0.01

        # NOTE: the schedule is hard-coded here; config.scheduler_config is not
        # consulted (the original factory call is kept for reference):
        #   config.scheduler_config.factory(self.optimizer)
        self.scheduler = LambdaLR(self.optimizer, lr_lambda=lr_lambda)
        self.train_loader = CustomDataLoader(train_set, batch_size=config.batch_size, shuffle=True)
        self.test_loader = torch.utils.data.DataLoader(test_set, batch_size=config.batch_size, shuffle=False)
        self.metrics = metrics or []
        self.logger = Logger(project=config.project, entity=config.entity, logging_steps=config.logging_steps, metrics=[], out_file=None, use_df=False)
        self.checkpoints = CheckpointManager(f"{model.__class__.__name__}18/{self.train_loader.dataset.__class__.__name__}", 'devinterp')  # TODO: read 18 automatically

    def measure(self):
        """Run every metric on this learner and merge the resulting dicts.

        Later metrics override earlier ones on key collisions.
        """
        results = {}
        for metric in self.metrics:
            results.update(metric(self))
        return results

    def resume(self, batch_idx: Optional[int] = None) -> Tuple[int, int]:
        """Restore model/optimizer/scheduler state from a saved checkpoint.

        Args:
            batch_idx: Global step to resume from. ``None`` picks the last
                recorded checkpoint; otherwise the checkpoint whose step is
                closest to ``batch_idx`` is used (with a warning when it is
                not an exact match).

        Returns:
            The ``(epoch, batch_idx)`` of the checkpoint that was loaded.

        Raises:
            ValueError: If no checkpoints have been recorded.
        """
        # NOTE(review): relies on `CheckpointManager.checkpoints` being the
        # list of (epoch, batch_idx) pairs in save order, as the analysis
        # cells of this notebook use it - confirm against devinterp.
        available = list(self.checkpoints.checkpoints)

        if not available:
            raise ValueError("No checkpoints available to resume from.")

        if batch_idx is None:
            epoch, closest = available[-1]
        else:
            epoch, closest = min(available, key=lambda entry: abs(entry[1] - batch_idx))

            if closest != batch_idx:
                warnings.warn(f"Could not find checkpoint with batch_idx {batch_idx}. Resuming from closest batch ({closest}) instead.")

        # Load the checkpoint that actually exists, not the requested step.
        self.load_checkpoint(epoch, closest)
        return epoch, closest

    def train(self, resume=False, run_id: Optional[str] = None):
        """Run the training loop, logging metrics and saving checkpoints.

        Args:
            resume: ``False`` trains from scratch. ``True`` resumes from the
                last checkpoint; an integer resumes from the checkpoint
                closest to that global step.
            run_id: Existing wandb run id to continue logging into. Without
                it a new wandb run is started.
        """
        # Move the model first so that a restored optimizer state is placed on
        # the same device as the parameters it belongs to.
        self.model.to(self.config.device)

        start_epoch, first_step = 0, 0

        if resume:
            target_step = None if resume is True else int(resume)
            # The scheduler position is restored through its state dict.
            start_epoch, last_step = self.resume(target_step)
            # The checkpoint was written after `last_step` finished.
            first_step = last_step + 1

        self.model.train()

        if self.config.is_wandb_enabled:
            if resume and not run_id:
                warnings.warn("Resuming from checkpoint but no run_id provided. Will not log to existing wandb run.")

            wandb_kwargs = dict(project=self.config.project, entity=self.config.entity)

            if run_id:
                # `id` + `resume` is how wandb attaches to an existing run.
                wandb_kwargs.update(id=run_id, resume="allow")

            wandb.init(**wandb_kwargs)

        pbar = tqdm(
            total=self.config.num_steps,
            initial=first_step,
            desc=f"Epoch 0 Batch 0/{self.config.num_steps} Loss: ?.??????",
        )

        try:
            for epoch in range(start_epoch, self.config.num_epochs):
                # Re-seeding per epoch makes the shuffle order a function of
                # the epoch number, so a resumed epoch replays the same order.
                self.set_seed(epoch)

                for _batch_idx, (data, target) in enumerate(self.train_loader):
                    batch_idx = self.config.num_steps_per_epoch * epoch + _batch_idx

                    # When resuming mid-epoch, skip batches already trained on.
                    if batch_idx < first_step:
                        continue

                    data, target = data.to(self.config.device), target.to(self.config.device)
                    self.optimizer.zero_grad()
                    output = self.model(data)
                    loss = F.cross_entropy(output, target)
                    loss.backward()
                    self.optimizer.step()

                    if self.scheduler:
                        self.scheduler.step()

                    # Update progress bar description
                    pbar.set_description(f"Epoch {epoch} Batch {batch_idx}/{self.config.num_steps} Loss: {loss.item():.6f}")
                    pbar.update(1)

                    if self.config.is_wandb_enabled:
                        # TODO: Figure out how to make this work with Logger
                        wandb.log({"Batch/Loss": loss.item()}, step=batch_idx)

                    # Save checkpoints / log metrics on the configured steps
                    if batch_idx in self.config.checkpoint_steps:
                        self.save_checkpoint(epoch, batch_idx)

                    if batch_idx in self.config.logging_steps:
                        self.logger.log(self.measure(), step=batch_idx)
                        # Metrics may switch the model to eval mode.
                        self.model.train()
        finally:
            # Close the bar once, after the last epoch (or on failure).
            pbar.close()

            if self.config.is_wandb_enabled:
                wandb.finish()

    def state_dict(self) -> LearnerStateDict:
        """Return the learner's own model/optimizer/scheduler state dicts."""
        return {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "scheduler": self.scheduler.state_dict() if self.scheduler is not None else None,
        }

    def load_state_dict(self, checkpoint: LearnerStateDict):
        """Restore model and optimizer state, and scheduler state when both
        the learner and the checkpoint have one."""
        self.model.load_state_dict(checkpoint["model"])
        self.optimizer.load_state_dict(checkpoint["optimizer"])

        if self.scheduler is not None and checkpoint["scheduler"] is not None:
            self.scheduler.load_state_dict(checkpoint["scheduler"])

    def save_checkpoint(self, epoch: int, batch_idx: int):
        """Store the current state under ``(epoch, batch_idx)``."""
        checkpoint = self.state_dict()
        self.checkpoints.save_checkpoint(checkpoint, epoch, batch_idx)

    def load_checkpoint(self, epoch: int, batch_idx: int):
        """Fetch the checkpoint stored under ``(epoch, batch_idx)`` and load it."""
        checkpoint = self.checkpoints.load_checkpoint(epoch, batch_idx)
        self.load_state_dict(checkpoint)

    def set_seed(self, seed: int):
        """Seed numpy, torch (CPU and, on CUDA devices, GPU), ``random`` and
        the training loader's shuffle with ``seed``."""
        np.random.seed(seed)
        torch.manual_seed(seed)
        random.seed(seed)
        self.train_loader.shuffle_data(seed=seed)

        if "cuda" in str(self.config.device):
            torch.cuda.manual_seed_all(seed)
    

In [13]:
def eval_model(model: torch.nn.Module, loader: torch.utils.data.DataLoader, config: Config):
    """Compute the mean cross-entropy loss and the accuracy of ``model`` over ``loader``.

    The model is switched to eval mode and left there; gradients are disabled
    for the whole pass. Returns ``(mean_loss, accuracy)`` as Python floats.
    """
    device = config.device

    # Running totals are kept as one-element tensors on the target device.
    loss_sum = torch.zeros(1, device=device)
    num_correct = torch.zeros(1, device=device)
    num_seen = torch.zeros(1, device=device)

    model.eval()
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            logits = model(inputs)
            # Summed (not averaged) so batches of different size weigh correctly.
            loss_sum += F.cross_entropy(logits, labels, reduction="sum")

            predictions = logits.argmax(dim=1, keepdim=True)
            hits = predictions.eq(labels.view_as(predictions))
            num_correct += hits.sum().item()
            num_seen += len(inputs)

    mean_loss = loss_sum / num_seen
    accuracy = num_correct / num_seen

    return mean_loss.item(), accuracy.item()


def eval_learner(learner: Learner):
    """Evaluate the learner's model on its train and test loaders.

    Returns a dict with loss and accuracy for each split.
    """
    net, cfg = learner.model, learner.config

    # Train split first, then test split, each as (loss, accuracy).
    train_metrics = eval_model(net, learner.train_loader, cfg)
    test_metrics = eval_model(net, learner.test_loader, cfg)

    names = ("Train/Loss", "Train/Accuracy", "Test/Loss", "Test/Accuracy")
    return dict(zip(names, (*train_metrics, *test_metrics)))

# Build the learner with `eval_learner` as its only metric, then train with
# the default arguments (no resume, no existing wandb run id).
learner = Learner(model, train_set, test_set, config, metrics=[eval_learner])
learner.train()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjqhoogland[0m ([33mdevinterp[0m). Use [1m`wandb login --relogin`[0m to force relogin


Epoch 0 Batch 0/64000 Loss: ?.??????:   0%|          | 0/64000 [00:00<?, ?it/s]

# Feature visualization

We have a trained `model` (and a bunch of checkpoints). First, let's do some classic feature visualization on the final network. We'll select a few random neurons from ac

In [None]:
# Handle on the saved checkpoints used by the analysis cells.
# NOTE(review): Learner builds its path as
# f"{model.__class__.__name__}18/{dataset class name}", which differs in
# letter case from 'resnet18/cifar10' - confirm both point at the same store.
checkpoints = CheckpointManager('resnet18/cifar10', 'devinterp')


In [None]:
# Rebuild an untrained ResNet-18 and overwrite its weights with `checkpoints[-1]`
# (presumably the most recent checkpoint - confirm).
model = torchvision.models.resnet18(pretrained=False)
# NOTE(review): Learner.state_dict() saves a dict with "model" / "optimizer" /
# "scheduler" entries; this call assumes the stored object is the bare model
# state dict - confirm the format of these checkpoints.
model.load_state_dict(checkpoints[-1])
# model: nn.Module = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
# model: nn.Module = torch.hub.load('pytorch/vision:v0.10.0', 'inception_v1', pretrained=True)
# Display the state-dict keys; these are the names neuron locations are built from.
model.state_dict().keys()

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt

class ActivationExtractor:
    """Capture one scalar activation from a layer's output via a forward hook.

    ``location`` is a dotted string: a ``state_dict`` key of the model
    (e.g. ``layer1.0.conv1.weight``) followed by the neuron indices. The
    hook is attached to the module that owns that parameter.

    Supported location formats:

    - 'layer1.0.conv1.weight.3': Only channel is specified; y and x default to center.
    - 'layer1.0.conv1.weight.3.2': Channel and y are specified; x defaults to center.
    - 'layer1.0.conv1.weight.3..2': Channel and x are specified; y defaults to center.
    - 'layer1.0.conv1.weight.3.2.2': Channel, y, and x are all specified.

    Raises:
        ValueError: If the location carries no channel index, or an index is
            not an integer.
    """

    def __init__(self, model, location):
        self.activation = None
        self.model = model
        self.location = location.split('.')
        self.layer_path = []
        self.channel = None
        self.y = None
        self.x = None

        # Longest leading run of parts that names a state-dict entry (or a
        # dotted prefix of one). Matching on whole components keeps e.g.
        # 'layer1' from matching a key that merely starts with those letters.
        state_dict_keys = list(model.state_dict().keys())
        depth = 0
        for end in range(1, len(self.location) + 1):
            path = '.'.join(self.location[:end])

            if not any(key == path or key.startswith(path + '.') for key in state_dict_keys):
                break

            depth = end

        self.layer_path = self.location[:depth]
        indices = self.location[depth:]

        if not indices or not indices[0]:
            raise ValueError(f"Location {location!r} does not specify a channel index.")

        # An empty component ('3..2') means "use the default" for that axis.
        self.channel = int(indices[0])
        self.y = self._parse_optional_index(indices[1]) if len(indices) > 1 else None
        self.x = self._parse_optional_index(indices[2]) if len(indices) > 2 else None

        # Walk to the module owning the parameter; the last path element is
        # the parameter name itself (e.g. 'weight'), hence the [:-1].
        self.layer = model
        for part in self.layer_path[:-1]:
            self.layer = getattr(self.layer, part)

    @staticmethod
    def _parse_optional_index(text):
        """Return ``int(text)``, or ``None`` for an empty component."""
        return int(text) if text else None

    def hook_fn(self, module, input, output):
        """Forward hook: store the selected unit's value for the first sample.

        Indexes the output as (sample, channel, y, x); missing y / x default
        to the centre of the respective axis.
        """
        y = self.y if self.y is not None else output.size(2) // 2
        x = self.x if self.x is not None else output.size(3) // 2

        self.activation = output[0, self.channel, y, x]

    def register_hook(self):
        """Attach :meth:`hook_fn` to the target layer and return the handle.

        The caller is responsible for calling ``handle.remove()``.
        """
        handle = self.layer.register_forward_hook(self.hook_fn)
        return handle


def gen_image(image: torch.Tensor):
    """Convert an image tensor into a uint8 HWC numpy array for display.

    The tensor is detached, moved to the CPU, stripped of a leading batch
    dimension of size one, rescaled so its values span [0, 1], and converted
    to 0-255 integers. The input tensor is left untouched.
    """
    image = image.detach().cpu().squeeze(0)

    # Out-of-place on purpose: for a CPU input, detach()/cpu() share storage
    # with the caller's tensor, so `-=` / `/=` would overwrite it.
    image = image - image.min()
    # Guard the 0/0 case of a constant image (yields an all-zero picture).
    image = image / image.max().clamp_min(1e-12)

    # Create grid
    grid_image = vutils.make_grid([image], nrow=1)

    # Scale to 0-255, round, and move channels last for plotting
    grid_image_np = grid_image.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()

    return grid_image_np

def show_image(image: torch.Tensor):
    """Display a single image tensor in a 5x5-inch matplotlib figure, axes hidden."""
    pixels = gen_image(image)  # uint8 array ready for imshow

    plt.figure(figsize=(5, 5))
    plt.imshow(pixels)
    plt.axis('off')
    plt.show()

def show_images(*images: torch.Tensor, nrow=None, **kwargs):
    """Display several image tensors side by side in one matplotlib figure.

    Args:
        *images: Image tensors; each is rescaled independently so its values
            span [0, 1], and a leading batch dimension of size one is dropped.
        nrow: Number of images per grid row (defaults to all in one row).
        **kwargs: Extra keyword arguments for ``plt.figure`` (e.g. ``dpi``).

    Does nothing when called without images.
    """
    if not images:
        return

    normalized = []
    for img in images:
        shifted = img.detach() - img.min()
        # clamp_min guards the 0/0 case of a constant image.
        scaled = shifted / shifted.max().clamp_min(1e-12)
        normalized.append(scaled.squeeze(0))

    # Create grid
    grid_image = vutils.make_grid(normalized, nrow=nrow or len(normalized))

    # Scale to 0-255, round, and move channels last for plotting
    grid_image_np = grid_image.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()

    # Display using matplotlib
    plt.figure(figsize=(15, 30), **kwargs)
    plt.imshow(grid_image_np)
    plt.axis('off')  # hide the axes
    plt.show()


def add_jitter(input_image, jitter_amount=2):
    """Apply jitter by circularly shifting the image by a random offset.

    Args:
        input_image: Tensor with at least four dimensions; dims 2 and 3 are
            rolled.
        jitter_amount: Maximum shift per axis. Only the magnitude matters;
            a falsy value returns the input unchanged.

    Returns:
        The shifted tensor (or ``input_image`` itself when jitter is off).
    """
    if not jitter_amount:
        return input_image

    # randint needs low < high; using the magnitude makes positive amounts
    # (including the default) work. The upper bound is exclusive, hence +1.
    max_shift = abs(int(jitter_amount))
    x_shift, y_shift = torch.randint(-max_shift, max_shift + 1, (2,)).tolist()
    return torch.roll(input_image, shifts=(x_shift, y_shift), dims=(2, 3))

def render(model: nn.Module, location: str, thresholds: list[int]=[512], verbose: bool = True, seed: int = 0, device: Optional[str] = None) -> tuple[list[torch.Tensor], float]:
    """Optimise an input image to maximise one neuron's activation.

    Args:
        model: Network to visualise; it is moved to ``device`` and put in
            eval mode.
        location: Neuron location string understood by ``ActivationExtractor``.
        thresholds: Iteration indices at which a snapshot of the image is
            kept. Optimisation runs for ``max(thresholds) + 1`` iterations.
        verbose: Show a progress bar with the current activation.
        seed: Seed for the random starting image.
        device: Device to run on. ``None`` uses the device of the model's
            first parameter (CPU for a parameterless model).

    Returns:
        ``(snapshots, activation)``: the detached image copies taken at the
        requested iterations, and the activation recorded by the last
        forward pass.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.to(device)
    model = model.eval()
    extractor = ActivationExtractor(model, location)
    handle = extractor.register_hook()

    # Random 1x3x32x32 starting image
    torch.manual_seed(seed)
    input_image = torch.rand((1, 3, 32, 32), requires_grad=True, device=device)

    # Optimizer
    optimizer = optim.Adam([input_image], lr=0.01, weight_decay=1e-3)
    jitter_amount = 0  # jitter is currently switched off

    final_images = []

    # Optimization loop
    pbar = range(max(thresholds) + 1)

    if verbose:
        pbar = tqdm(pbar, desc=f"Visualizing {location} (activation: ???)")

    try:
        for iteration in pbar:
            optimizer.zero_grad()
            model(input_image)  # Forward pass through the model to trigger the hook
            activation = extractor.activation
            loss = -activation  # Maximizing activation
            loss.backward()
            optimizer.step()

            if jitter_amount:
                input_image.data = add_jitter(input_image.data.detach().clone(), jitter_amount=-jitter_amount)

            if verbose:
                pbar.set_description(f"Visualizing {location} (activation: {activation.item():.2f})")

            if iteration in thresholds:
                final_images.append(input_image.detach().clone())
    finally:
        # Always detach the hook, even if optimisation fails part-way.
        handle.remove()

    return final_images, extractor.activation.item()


def render_multiple(model: nn.Module, *locations: str, thresholds: list[int]=[512], verbose: bool = True, init_seed: int = 0, device: str = "cuda", **kwargs) -> list[tuple[list[torch.Tensor], float]]:
    """Run ``render`` for each location and collect the results.

    The i-th location is rendered with seed ``init_seed + i``. When
    ``verbose`` is set, each location's snapshots are shown with
    ``show_images``; extra keyword arguments are forwarded there.

    Returns one ``(snapshots, activation)`` pair per location, in order.
    """
    outcomes = []

    for seed, location in enumerate(locations, start=init_seed):
        snapshots, activation = render(
            model,
            location=location,
            thresholds=thresholds,
            verbose=verbose,
            seed=seed,
            device=device,
        )

        if verbose:
            show_images(*snapshots, **kwargs)

        outcomes.append((snapshots, activation))

    return outcomes

In [None]:
# Visualise six hand-picked channels from layer1 and layer2, keeping
# snapshots at iterations 0, 64, 128, 256 and 512.
results = render_multiple(
    model,    
    'layer1.0.conv1.weight.0',
    'layer1.0.conv2.weight.1',
    'layer1.1.conv1.weight.7',
    'layer1.1.conv2.weight.4',
    'layer2.0.conv1.weight.3',
    'layer2.0.conv2.weight.2',
    thresholds=[0, 64, 128, 256, 512],
    verbose=False,
    device="cuda:0"
)
# Show only the last snapshot of each neuron, side by side.
show_images(*[images[-1] for (images, _) in results], dpi=50)

### Maximally active neurons

Let's go through all neurons in the model and rank them by their activation. We will then plot the top 10 most active neurons.

In [None]:
from torch.nn import Conv2d

def gen_conv_neurons(model: nn.Module):
    """List a location string for every output channel of every Conv2d in ``model``.

    Locations have the form ``<module path>.weight.<channel>`` and are
    produced in depth-first order over the module tree.
    """

    def walk(module, prefix):
        # Yield the channels of each conv child, then descend into that child.
        for name, child in module.named_children():
            child_path = f"{prefix}.{name}" if prefix else name

            if isinstance(child, Conv2d):
                for channel in range(child.out_channels):
                    yield f"{child_path}.weight.{channel}"

            yield from walk(child, child_path)

    return list(walk(model, ''))

# All conv channel locations except the first 64 entries
# (presumably the channels of the first conv layer - confirm).
conv_neurons = gen_conv_neurons(model)[64:]
print(conv_neurons)

In [None]:
# Render every neuron in `conv_neurons`, ten at a time, and persist the
# accumulated (snapshots, activation) pairs along the way.
neurons_results = []

for i in range(0, len(conv_neurons), 10):
    section = conv_neurons[i:i+10]
    print(section)
    _results = render_multiple(model, *section, thresholds=[256], device="cuda:0", verbose=False)
    show_images(*[images[-1] for (images, _) in _results], dpi=50)
    neurons_results.extend(_results)

    if i % 100 == 0:
        print(f"Saving results at {i}")
        torch.save(neurons_results, "../visualizations/restnet-cifar10.pt")

        # Print the 100 most activated neurons
        print([(name, activation) for (name, (_, activation)) in sorted(zip(conv_neurons, neurons_results), key=lambda x: x[1][1], reverse=True)[:100]])

# Final save: without it, everything rendered after the last multiple of 100
# would never reach disk.
torch.save(neurons_results, "../visualizations/restnet-cifar10.pt")

### Developmental analysis of a sample neuron

In [None]:
# The parameter is named "weight" in the state dict; any other spelling
# cannot be resolved by ActivationExtractor.
sample_neuron = "layer1.1.conv1.weight.7"
# render returns (snapshots, activation); keep the last snapshot as `viz`.
vizs, activation = render(model, sample_neuron, seed=0, device=config.device)
viz = vizs[-1]
print(activation)
show_image(viz)

In [None]:
# Replay the final visualisation `viz` through every checkpoint and record how
# strongly the sample neuron responds to it at each point in training.
pbar = tqdm(checkpoints, desc="Looping checkpoints (activation: ???)")
activations = []

model.eval()

for state_dict in pbar:
    model.load_state_dict(state_dict)
    extractor = ActivationExtractor(model, sample_neuron)
    handle = extractor.register_hook()

    try:
        with torch.no_grad():
            model(viz)
    finally:
        # The same module objects are reused for every checkpoint, so a hook
        # left registered would fire again on all later forward passes.
        handle.remove()

    # Store plain floats so the values can be plotted directly.
    activation = extractor.activation.item()
    activations.append(activation)

    pbar.set_description(f"Looping checkpoints (activation: {activation:.2f})")

In [None]:
# Plot the sample neuron's activation against training step for the last
# five checkpoints only.
all_steps = [step for _, step in checkpoints.checkpoints]
plt.plot(all_steps[-5:], activations[-5:])
plt.xlabel("Training step")
plt.ylabel("Activation")

In [None]:
# Run feature visualization at a few points in training: the very start,
# step 90 (where the activation reaches a minimum), step 5k (where it is
# close to 0), steps 8600 and 9000, and the last step.
# NOTE(review): the list below has five entries and no explicit step 0 -
# confirm whether 90 is meant to stand in for "the very start".

# First, find the saved checkpoints closest to these steps.

ideal_checkpoint_steps = [90, 5000, 8600, 9000, 9999]

def get_closest_checkpoint(checkpoints: list[tuple[int, int]], step: int) -> tuple[int, int]:
    """Return the ``(epoch, batch_idx)`` entry whose batch index is nearest to ``step``.

    Ties go to the entry that appears first. Raises ``ValueError`` when
    ``checkpoints`` is empty.
    """
    return min(checkpoints, key=lambda checkpoint: abs(checkpoint[1] - step))

# Map each target step to the nearest saved (epoch, batch_idx) pair,
# then display the result.
checkpoint_steps = [get_closest_checkpoint(checkpoints.checkpoints, step) for step in ideal_checkpoint_steps]
checkpoint_steps

In [None]:
# Re-run feature visualization for the sample neuron at each selected
# checkpoint and show the snapshots taken during optimisation.
for (epoch, batch_idx) in tqdm(checkpoint_steps, desc="Going through checkpoints"):
    model.load_state_dict(checkpoints[(epoch, batch_idx)])
    # Pass the device explicitly rather than relying on render's default.
    vizs, activation = render(model, sample_neuron, seed=0, thresholds=[0, 64, 128, 256, 512], verbose=True, device=config.device)
    show_images(*vizs)

### Let's do a whole set of neurons

In [None]:
# Load the saved (snapshots, activation) pairs onto the CPU, attach the neuron
# name to each, and order the entries by ascending activation.
_raw_results = torch.load("../visualizations/restnet-cifar10.pt", map_location=torch.device('cpu'))
viz_results = [
    (name, activation, snapshots)
    for (snapshots, activation), name in zip(_raw_results, conv_neurons)
]
viz_results.sort(key=lambda entry: entry[1])

In [None]:
import numpy as np
from matplotlib import pyplot as plt

# Thresholds separating "insignificant", "moderate" and "large" activations.
eps = 1e-4
large_eps = 100

activations = [entry[1] for entry in viz_results]

# Count how many neurons fall into each activation range.
print("# Negative activations: ", sum(1 for a in activations if a < 0))
print("# Zero activations: ", sum(1 for a in activations if a == 0))
print("# Insignificant positive activations: ", sum(1 for a in activations if 0 < a <= eps))
print("# Moderate positive activations: ", sum(1 for a in activations if eps < a <= large_eps))
print("# Large positive activations: ", sum(1 for a in activations if large_eps < a))

# Only activations above eps go into the histogram (log bins need positive values).
activations = [a for a in activations if a > eps]

lowest, highest = min(activations), max(activations)
log_bins = np.logspace(np.log10(lowest), np.log10(highest), num=10)

# Histogram over logarithmically spaced bins, drawn on a log x-axis.
plt.hist(activations, bins=log_bins)
plt.xscale('log')


In [None]:
# Choose 5 from each category by random
np.random.seed(2)

sample_neurons = [
    *np.random.choice([n for n, a, _ in viz_results if a < -eps], size=5, replace=False),
    *np.random.choice([n for n, a, _ in viz_results if -eps <= a <= eps], size=5, replace=False),
    *np.random.choice([n for n, a, _ in viz_results if eps < a <= large_eps], size=5, replace=False),
    *np.random.choice([n for n, a, _ in viz_results if large_eps < a], size=5, replace=False),
]
print(sample_neurons)
images = [imgs[-1] for n, _, imgs in viz_results if n in sample_neurons]
show_images(
    *images,
    nrow=5
)

In [None]:
def evolve_multiple(model: nn.Module, checkpoints: CheckpointManager, *locations: str, opt_steps: int = 512, **kwargs):
    """Track a set of neurons across all checkpoints and log the results to wandb.

    For each location, a visualisation is first optimised on the last
    checkpoint. Then, for every checkpoint, the activation of each neuron on
    its final visualisation is recorded and logged; on every 20th checkpoint
    (and on the last one) the visualisations are re-optimised against that
    checkpoint's weights and logged as images.

    Args:
        model: Network whose weights are overwritten checkpoint by checkpoint.
        checkpoints: Checkpoint store; iterating yields state dicts and
            ``checkpoints.checkpoints[i][1]`` gives the i-th batch index.
        *locations: Neuron location strings.
        opt_steps: Optimisation iteration at which each visualisation is taken.
        **kwargs: Forwarded to ``render_multiple`` (e.g. ``device``, ``verbose``).

    Returns:
        ``(vizs, activations)``: per location, the list of re-optimised
        visualisations and the list of activations (one per checkpoint).
    """
    model.load_state_dict(checkpoints[-1])
    model.eval()

    final_vizs: dict[str, torch.Tensor] = {}
    vizs: dict[str, list[torch.Tensor]] = {}
    activations: dict[str, list[float]] = {}

    # Create the visualizations for the last checkpoint
    final_renders = render_multiple(model, *locations, thresholds=[opt_steps], **kwargs)
    for location, (images, _) in zip(locations, final_renders):
        final_vizs[location] = images[0]
        vizs[location] = []
        activations[location] = []

    last_index = len(checkpoints) - 1

    for i, state_dict in enumerate(tqdm(checkpoints, desc="Visiting checkpoints")):
        batch_idx = checkpoints.checkpoints[i][1]

        # Load this checkpoint's weights once for all locations
        model.load_state_dict(state_dict)

        for location in locations:
            extractor = ActivationExtractor(model, location)
            handle = extractor.register_hook()

            try:
                with torch.no_grad():
                    model(final_vizs[location])
            finally:
                handle.remove()

            activations[location].append(extractor.activation.item())

        wandb.log({f"Activations/{location}": activations[location][-1] for location in locations}, step=batch_idx, commit=False)

        # Re-visualize on every 20th checkpoint and on the final one
        if i % 20 == 0 or i == last_index:
            renders = render_multiple(model, *locations, thresholds=[opt_steps], **kwargs)

            for location, (images, _) in zip(locations, renders):
                viz = images[0]
                vizs[location].append(viz)
                image_np = gen_image(viz)
                image = wandb.Image(image_np, caption=f"Optimized {location} at batch {batch_idx}")

                wandb.log({f"Visualizations/{location}": image}, step=batch_idx)

    return vizs, activations


In [None]:
# wandb.finish()
# run_id = input("Run ID: ")
# Start a new wandb run, then track the sampled neurons across all checkpoints
# on the CPU with per-neuron progress output switched off.
wandb.init(project=config.project, entity=config.entity)
results = evolve_multiple(model, checkpoints, *sample_neurons, device="cpu", verbose=False)

In [None]:
# Bare list of 20 neuron locations; as the cell's last expression it is only displayed
# (presumably a saved copy of a `sample_neurons` draw - confirm).
['layer3.0.conv2.weight.206', 'layer3.0.conv1.weight.149', 'layer3.0.conv1.weight.118', 'layer3.0.conv2.weight.63', 'layer2.0.conv1.weight.102', 'layer2.0.conv2.weight.110', 'layer2.0.conv1.weight.15', 'layer3.1.conv1.weight.20', 'layer2.0.conv2.weight.19', 'layer3.0.conv1.weight.205', 'layer1.0.conv2.weight.54', 'layer2.0.conv2.weight.12', 'layer2.0.downsample.0.weight.99', 'layer1.0.conv2.weight.0', 'layer1.0.conv1.weight.47', 'layer1.0.conv2.weight.41', 'layer1.0.conv2.weight.51', 'layer3.0.downsample.0.weight.125', 'layer2.0.conv1.weight.108', 'layer1.1.conv2.weight.32']
