# Vision models

## Orientation: ResNet-18


1. Let's start by training a simple ResNet-18 model and take lots of checkpoints.
2. Then do feature visualization on the end results (for a random sample of neurons). 
3. Look at how the activation of the target neuron reacts to those feature visualizations over the course of training.

In [48]:
import torch
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
import os
from dataclasses import dataclass, field
from typing import Optional, Container, Tuple
from dataclasses import asdict
import math
import numpy as np

In [2]:
torch.manual_seed(0)
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False)

Using cache found in /Users/Jesse/.cache/torch/hub/pytorch_vision_v0.10.0


In [3]:
# Data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_set = datasets.CIFAR10(root='../data', train=True, download=True, transform=transform)
test_set = datasets.CIFAR10(root='../data', train=False, download=True, transform=transform)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:23<00:00, 7285702.43it/s] 


Extracting ../data/cifar-10-python.tar.gz to ../data
Files already downloaded and verified


In [55]:
@dataclass
class Config:
    batch_size: int = 128
    lr: float = 0.01  # Starting lr
    weight_decay: float = 0.0001
    num_epochs: Optional[int] = None
    logging_steps: Optional[Container] = None
    project: Optional[str] = None
    entity: Optional[str] = None
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    betas: Tuple[float, float] = (0.9, 0.999)

    def __post_init__(self):
        if self.num_epochs is None:
            # Default to 64k steps
            self.num_epochs = 64000 * self.batch_size // len(train_set)

        if self.logging_steps is None or isinstance(self.logging_steps, int):
            logging_steps = self.logging_steps
            # By default: 1x per epoch
            self.logging_steps = set([i * len(train_set) // self.batch_size for i in range(self.num_epochs)]) 

            if isinstance(logging_steps, int):
                # Logscale from with self.logging_steps steps between 0 and num_epochs * len(train_data) // self.batch_size
                self.logging_steps |= {int(i) for i in np.logspace(0, np.log10(self.num_epochs * len(train_set) // self.batch_size), logging_steps)}
                
config = Config(project="resnet18", entity="devinterp", logging_steps=100)
steps = sorted(list(config.logging_steps))
config

Config(batch_size=128, lr=0.01, weight_decay=0.0001, num_epochs=163, logging_steps={0, 1, 37890, 2, 3, 4, 23046, 5, 6, 60937, 7, 8203, 8, 46093, 9, 10, 11, 521, 31250, 13, 14, 16, 16406, 18, 54296, 20, 1562, 22, 25, 39453, 28, 31, 24609, 35, 62500, 9765, 1019, 39, 47656, 32812, 44, 3117, 17968, 49, 14897, 55859, 3125, 41015, 55, 1594, 26171, 62, 11328, 49218, 36418, 69, 34375, 583, 19531, 57421, 78, 4687, 42578, 27734, 87, 12890, 50781, 35937, 97, 21093, 58984, 6250, 56939, 44140, 109, 29296, 1140, 14453, 52343, 122, 37500, 22656, 60546, 7812, 45703, 136, 11914, 30859, 652, 16015, 53906, 1171, 39062, 152, 24218, 62109, 6814, 9375, 47265, 32421, 17578, 170, 55468, 2734, 40625, 25781, 2229, 63670, 10937, 48828, 190, 33984, 19140, 18628, 57031, 4296, 42187, 27343, 12500, 213, 50390, 729, 35546, 20703, 58593, 5859, 2787, 43750, 50920, 28906, 14062, 238, 51953, 37109, 1783, 22265, 1275, 60156, 7421, 23292, 45312, 30468, 4358, 15625, 4873, 53515, 266, 781, 38671, 16658, 40723, 23828, 61718, 

In [56]:
def maybe_initialize_wandb(project_name: Optional[str] = None, entity: Optional[str] = None):
    if project_name:
        import wandb
        wandb.init(project=project_name, entity=entity)
        return wandb
    return None

wandb = maybe_initialize_wandb(config.project, config.entity)

VBox(children=(Label(value='0.001 MB of 0.003 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.327687…

0,1
Batch/Loss,▆▆▃▃█▆▅▃▃▄▅▃▃▃▁▃▄▂▃▃▃▂▃▃▃▂▂▃▅▂▃▄▂▃▂▃▃▂▂▃

0,1
Batch/Loss,1.45581


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.0167448923670842, max=1.0))…

In [57]:
from torch.optim.lr_scheduler import LambdaLR
from torch import optim

optimizer = optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.weight_decay, betas=config.betas)
scheduler = LambdaLR(optimizer, milestones=[32_000, 48_000], gamma=0.1)  # Same as in the paper 

torch.manual_seed(1)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=config.batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=config.batch_size, shuffle=False)

TypeError: LambdaLR.__init__() got an unexpected keyword argument 'milestones'

In [58]:
def train(model: torch.nn.Module, train_loader: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer, scheduler, logging_steps: set,device: torch.device, num_epochs=10, project="resnet18", **kwargs):
    model.train()

    # Calculate total number of batches
    num_batches_per_epoch = len(train_loader)
    total_batches = num_batches_per_epoch * num_epochs

    pbar = tqdm(total=total_batches, desc=f"Epoch 0 Batch 0/{total_batches} Loss: ?.??????")
    
    for epoch in range(1, num_epochs + 1):
        for _batch_idx, (data, target) in enumerate(train_loader, 1):  # Start batch_idx from 1
            batch_idx = num_batches_per_epoch * epoch + _batch_idx
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
            scheduler.step()
            
            # Update progress bar description
            pbar.set_description(f"Epoch {epoch} Batch {batch_idx}/{total_batches} Loss: {loss.item():.6f}")
            pbar.update(1)
            
            if wandb:
                wandb.log({"Batch/Loss": loss.item()}, step=batch_idx)

            # Log to wandb & save checkpoints according to log_steps
            if batch_idx in logging_steps:
                torch.save(model.state_dict(), f"../checkpoints/{project}/checkpoint_epoch_{epoch}_batch_{batch_idx}.pt")
        
        pbar.close()


In [54]:
train(model, train_loader, optimizer, scheduler, **asdict(config))

Epoch 0 Batch 0/63733 Loss: ?.??????:   0%|          | 0/63733 [00:00<?, ?it/s]

KeyboardInterrupt: 

# Feature visualization

We have a trained `model` (and a bunch of checkpoints). First, let's do some classic feature visualization on the final network. We'll select a few random neurons from ac