In [None]:
import wandb

# Start the W&B run. The hyperparameters registered here are read back
# through `wandb.config` in the next cell.
RUN_CONFIG = {"epochs": 10, "batch_size": 64, "learning_rate": 0.001}
wandb.init(project="resnet34-cifar10-feature-maps", config=RUN_CONFIG)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet34
import wandb
import matplotlib.pyplot as plt
import numpy as np

# Configuration: hyperparameters come from the active W&B run.
config = wandb.config
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's RNG (also used for CUDA, DataLoader shuffling / worker seeding
# and the random augmentations below) so Restart-&-Run-All reproduces the run.
SEED = 42
torch.manual_seed(SEED)

# Per-channel normalization statistics for CIFAR-10 (shared by both pipelines).
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

# Training applies light augmentation (random flip + padded random crop);
# evaluation only converts to a tensor and normalizes.
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Load CIFAR-10 (downloaded into ./data on first run)
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=config.batch_size, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

# Model, loss and optimizer. `weights=None` trains from scratch; it replaces the
# deprecated `pretrained=False` argument (torchvision >= 0.13).
# NOTE(review): the stock ImageNet stem (stride-2 conv + max-pool) shrinks 32x32
# inputs to 1x1 by layer4; consider a CIFAR-style stem if deep spatial maps matter.
model = resnet34(weights=None, num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)

# Helper functions to log kernels to wandb
def log_kernels_to_wandb(layer, layer_name, epoch, max_kernels=None):
    """Draw the convolution kernels of `layer` as an image grid and log it to W&B.

    Args:
        layer: module exposing a 4-D ``weight`` laid out as
            (out_channels, in_channels, kH, kW), e.g. ``nn.Conv2d``.
        layer_name: label used in the figure title and in the W&B key
            ``f"{layer_name}_kernels"``.
        epoch: epoch number shown in the figure title.
        max_kernels: optional cap on how many output kernels to draw (deep
            layers have hundreds); ``None`` draws all of them.
    """
    kernels = layer.weight.detach().cpu().numpy()
    if max_kernels is not None:
        kernels = kernels[:max_kernels]
    num_kernels, num_channels = kernels.shape[:2]
    if num_kernels == 0:
        return

    # Rescale all kernels jointly to [0, 1] (valid RGB input for imshow, and
    # relative magnitudes stay comparable). Guard against a constant tensor.
    k_min, k_max = kernels.min(), kernels.max()
    value_range = k_max - k_min
    if value_range > 0:
        kernels = (kernels - k_min) / value_range
    else:
        kernels = np.zeros_like(kernels)

    # Near-square grid large enough to hold every kernel.
    num_cols = int(np.ceil(np.sqrt(num_kernels)))
    num_rows = int(np.ceil(num_kernels / num_cols))

    # squeeze=False always returns a 2-D array of Axes, even for a 1x1 grid.
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(12, 12), squeeze=False)
    axes = axes.ravel()

    for ax, kernel in zip(axes, kernels):
        if num_channels == 3:
            # (3, kH, kW) -> (kH, kW, 3): draw input-facing kernels as RGB patches.
            ax.imshow(kernel.transpose(1, 2, 0))
        else:
            # imshow cannot render an arbitrary (C, kH, kW) stack; show the
            # mean over input channels as one 2-D map instead.
            ax.imshow(kernel.mean(axis=0), cmap="viridis")
        ax.axis("off")

    # Hide the axes of any unused grid cells.
    for ax in axes[num_kernels:]:
        ax.axis("off")

    fig.suptitle(f"{layer_name} Kernels at Epoch {epoch}")
    wandb.log({f"{layer_name}_kernels": wandb.Image(fig)})
    plt.close(fig)

# Helper function to visualize and log activation maps to wandb
def visualize_activation_maps(model, layer_name, input_image, epoch, max_maps=None):
    """Run one image through `model`, capture `layer_name`'s output and log it to W&B.

    Args:
        model: network to probe. It is switched to eval mode for the forward
            pass and then restored to the train/eval mode it had before.
        layer_name: dotted submodule name as listed by ``model.named_modules()``
            (e.g. ``"layer1.0.conv1"``); also used in the title and W&B key.
        input_image: a single un-batched input tensor; a batch dim is added here.
        epoch: epoch number shown in the figure title.
        max_maps: optional cap on how many channels to draw; ``None`` draws all.
    """
    captured = {}

    def hook_fn(module, inputs, output):
        captured["out"] = output.detach()

    layer = dict(model.named_modules())[layer_name]
    # Send the input to wherever the model lives instead of relying on a global.
    target_device = next(model.parameters()).device
    was_training = model.training

    hook = layer.register_forward_hook(hook_fn)
    try:
        model.eval()
        with torch.no_grad():
            model(input_image.unsqueeze(0).to(target_device))
    finally:
        # Always detach the hook and restore the caller's mode, even on error.
        hook.remove()
        model.train(was_training)

    # Drop ONLY the batch dimension. A bare .squeeze() also collapses 1x1
    # spatial dims (deep ResNet layers on 32x32 inputs), turning (C, 1, 1)
    # maps into a 1-D vector that was then never displayed.
    act_maps = captured["out"][0].cpu()

    if act_maps.dim() == 1:
        # A flat feature vector has no spatial layout to draw.
        print(f"Layer {layer_name} produced 1D outputs; cannot display.")
        return
    if act_maps.dim() == 2:
        # A single (H, W) map: add a channel dim so it is handled like the rest.
        act_maps = act_maps.unsqueeze(0)

    if max_maps is not None:
        act_maps = act_maps[:max_maps]
    num_maps = act_maps.size(0)
    if num_maps == 0:
        return

    # Near-square grid large enough to hold every map.
    num_cols = int(np.ceil(np.sqrt(num_maps)))
    num_rows = int(np.ceil(num_maps / num_cols))

    # squeeze=False always returns a 2-D array of Axes, even for a single map.
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(12, 12), squeeze=False)
    axes = axes.ravel()

    for ax, fmap in zip(axes, act_maps):
        ax.imshow(fmap, cmap="viridis")
        ax.axis("off")

    # Hide the axes of any unused grid cells.
    for ax in axes[num_maps:]:
        ax.axis("off")

    fig.suptitle(f"{layer_name} Activation Maps at Epoch {epoch}")
    wandb.log({f"{layer_name}_activation_maps": wandb.Image(fig)})
    plt.close(fig)


# Fixed probe image for activation-map logging: the first test-set image.
# testset uses the deterministic transform_test, so the probe is identical every
# epoch. (Previously `sample_img` was never defined and only worked via hidden
# kernel state.)
sample_img, _ = testset[0]

LOG_INTERVAL = 100  # mini-batches between loss logs
VIS_INTERVAL = 1    # epochs between kernel / activation-map logs

# Training loop
for epoch in range(config.epochs):
    # Re-enter train mode each epoch (the visualization helpers run eval passes).
    model.train()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Log the mean loss over the last LOG_INTERVAL mini-batches.
        if (i + 1) % LOG_INTERVAL == 0:
            avg_loss = running_loss / LOG_INTERVAL
            wandb.log({"epoch": epoch + 1, "loss": avg_loss})
            print(f"[Epoch {epoch+1}, Batch {i+1}] Loss: {avg_loss:.4f}")
            running_loss = 0.0

    if (epoch + 1) % VIS_INTERVAL == 0:
        # Kernels of the stem conv (low-level) and a layer4 conv (high-level).
        log_kernels_to_wandb(model.conv1, "Low_Level_Features", epoch + 1)
        log_kernels_to_wandb(model.layer4[0].conv1, "High_Level_Features", epoch + 1)

        # Activation maps of the fixed probe image at a shallow and a deep layer.
        visualize_activation_maps(model, "layer1.0.conv1", sample_img.to(device), epoch + 1)
        visualize_activation_maps(model, "layer4.0.conv1", sample_img.to(device), epoch + 1)

# Evaluate on the held-out test set, log the final accuracy, and close the run.
model.eval()
n_correct, n_seen = 0, 0
with torch.no_grad():
    for batch_inputs, batch_labels in testloader:
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)
        # Predicted class = index of the highest logit per row.
        predictions = model(batch_inputs).max(dim=1).indices
        n_seen += batch_labels.size(0)
        n_correct += predictions.eq(batch_labels).sum().item()

accuracy = 100 * n_correct / n_seen
wandb.log({"test_accuracy": accuracy})
print("Test Accuracy: {:.2f}%".format(accuracy))

wandb.finish()


0,1
epoch,▁▁▁▁▁▁▁
loss,█▅▄▃▂▁▁

0,1
epoch,1.0
loss,1.4078


Files already downloaded and verified
Files already downloaded and verified
[Epoch 1, Batch 100] Loss: 2.0758
[Epoch 1, Batch 200] Loss: 1.7921
[Epoch 1, Batch 300] Loss: 1.6708
[Epoch 1, Batch 400] Loss: 1.6245
[Epoch 1, Batch 500] Loss: 1.5623
[Epoch 1, Batch 600] Loss: 1.4773
[Epoch 1, Batch 700] Loss: 1.4565
Layer layer4.0.conv1 produced 1D outputs; cannot display.
[Epoch 2, Batch 100] Loss: 1.3883
[Epoch 2, Batch 200] Loss: 1.3265
[Epoch 2, Batch 300] Loss: 1.2894
[Epoch 2, Batch 400] Loss: 1.2392
[Epoch 2, Batch 500] Loss: 1.2242
[Epoch 2, Batch 600] Loss: 1.1964
[Epoch 2, Batch 700] Loss: 1.1686
Layer layer4.0.conv1 produced 1D outputs; cannot display.
[Epoch 3, Batch 100] Loss: 1.1279
[Epoch 3, Batch 200] Loss: 1.0594
[Epoch 3, Batch 300] Loss: 1.0989
[Epoch 3, Batch 400] Loss: 1.0933
[Epoch 3, Batch 500] Loss: 1.0567
[Epoch 3, Batch 600] Loss: 1.0670
[Epoch 3, Batch 700] Loss: 1.0470
Layer layer4.0.conv1 produced 1D outputs; cannot display.
[Epoch 4, Batch 100] Loss: 1.0446
[E