<a href="https://colab.research.google.com/github/protagora/learnable-activation-function/blob/dev/visualize_kernels.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import wandb

# Initialize wandb
wandb.init(project="resnet34-cifar10", config={"epochs": 10, "batch_size": 64, "learning_rate": 0.001})

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet34
import wandb
import matplotlib.pyplot as plt
import numpy as np

# Configuration
config = wandb.config
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define transformations
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])

# Load CIFAR-10 dataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=config.batch_size, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

# Define model, loss function, and optimizer
model = resnet34(pretrained=False, num_classes=10)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)


def log_kernels_to_wandb(layer, layer_name, epoch):
    kernels = layer.weight.data.cpu().numpy()
    num_kernels = kernels.shape[0]
    num_channels = kernels.shape[1]

    # Normalize kernels to [0, 1] for visualization
    kernels = (kernels - kernels.min()) / (kernels.max() - kernels.min())

    # Calculate grid size
    num_cols = int(np.ceil(np.sqrt(num_kernels)))
    num_rows = int(np.ceil(num_kernels / num_cols))

    fig, axes = plt.subplots(num_rows, num_cols, figsize=(12, 12))
    axes = axes.flatten()  # Flatten to easily iterate and avoid index issues

    for i in range(num_kernels):
        kernel = kernels[i]
        if num_channels == 3:  # RGB kernel
            kernel = kernel.transpose(1, 2, 0)
        axes[i].imshow(kernel, cmap="viridis")
        axes[i].axis("off")

    # Turn off any remaining empty subplots
    for j in range(i + 1, len(axes)):
        axes[j].axis("off")

    fig.suptitle(f"{layer_name} Kernels at Epoch {epoch}")
    wandb.log({f"{layer_name}_kernels": wandb.Image(fig)})
    plt.close(fig)

# Training loop
for epoch in range(config.epochs):
    model.train()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Log every 100 mini-batches
        if i % 100 == 99:
            avg_loss = running_loss / 100
            wandb.log({"epoch": epoch + 1, "loss": avg_loss})
            print(f"[Epoch {epoch+1}, Batch {i+1}] Loss: {avg_loss:.4f}")
            running_loss = 0.0

    # Save and log kernels for low-level and high-level features
    if epoch % 1 == 0:  # Log every epoch
        log_kernels_to_wandb(model.conv1, "Low_Level_Features", epoch + 1)
        log_kernels_to_wandb(model.layer3[0].conv1, "High_Level_Features", epoch + 1)  # Example of a high-level layer

# Testing the model and logging accuracy
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in testloader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
wandb.log({"test_accuracy": accuracy})
print("Test Accuracy: {:.2f}%".format(accuracy))

wandb.finish()


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:15<00:00, 11.1MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified




[Epoch 1, Batch 100] Loss: 2.0762
[Epoch 1, Batch 200] Loss: 1.7946
[Epoch 1, Batch 300] Loss: 1.7182
[Epoch 1, Batch 400] Loss: 1.6455
[Epoch 1, Batch 500] Loss: 1.5295
[Epoch 1, Batch 600] Loss: 1.4640
[Epoch 1, Batch 700] Loss: 1.4463
[Epoch 2, Batch 100] Loss: 1.3949
[Epoch 2, Batch 200] Loss: 1.3253
[Epoch 2, Batch 300] Loss: 1.3106
[Epoch 2, Batch 400] Loss: 1.2561
[Epoch 2, Batch 500] Loss: 1.2302
[Epoch 2, Batch 600] Loss: 1.2068
[Epoch 2, Batch 700] Loss: 1.1725
[Epoch 3, Batch 100] Loss: 1.1568
[Epoch 3, Batch 200] Loss: 1.1492
[Epoch 3, Batch 300] Loss: 1.1146
[Epoch 3, Batch 400] Loss: 1.0698
[Epoch 3, Batch 500] Loss: 1.0642
[Epoch 3, Batch 600] Loss: 1.0820
[Epoch 3, Batch 700] Loss: 1.0252
[Epoch 4, Batch 100] Loss: 1.0527
[Epoch 4, Batch 200] Loss: 0.9922
[Epoch 4, Batch 300] Loss: 0.9799
[Epoch 4, Batch 400] Loss: 0.9663
[Epoch 4, Batch 500] Loss: 0.9612
[Epoch 4, Batch 600] Loss: 0.9494
[Epoch 4, Batch 700] Loss: 0.9500
[Epoch 5, Batch 100] Loss: 0.8875
[Epoch 5, Batc

0,1
epoch,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▃▃▃▄▅▅▅▅▆▆▆▆▆▆▆▆▆▆▇▇▇██
loss,██▇▆▆▅▅▅▄▄▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▂▂▁▁▁▁▁
test_accuracy,▁

0,1
epoch,10.0
loss,0.6514
test_accuracy,78.75
