In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.adam import Adam
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader, Subset
from torchvision import transforms, datasets
from tqdm.notebook import tqdm
from utils import plot_full_evaluation, create_activation_image


# Pick the first CUDA device when one is present, otherwise fall back to the CPU.
gpu_available = torch.cuda.is_available()
device = torch.device("cuda:0") if gpu_available else torch.device("cpu")
print("Using gpu: %s " % gpu_available)

In [None]:
# torchvision dataset class to work with; swap which line is active to change it.
# datasets_function = datasets.CIFAR10
# datasets_function = datasets.FashionMNIST
datasets_function = datasets.MNIST

In [None]:
# ToTensor scales pixels to [0, 1]; Resize then upsamples every image to 224x224
# (torchvision's default bilinear interpolation keeps values inside [0, 1]).
transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((224, 224))])

try:  # Most torchvision datasets select the split with `train=`
    train_dataset = datasets_function(
        root="./data", train=True, transform=transform, download=True
    )
    test_dataset = datasets_function(
        root="./data", train=False, transform=transform, download=True
    )
except TypeError:  # VOC-style datasets reject `train=` and use `image_set=` instead
    train_dataset = datasets_function(
        root="./data", image_set="train", transform=transform, download=True
    )
    test_dataset = datasets_function(
        root="./data", image_set="val", transform=transform, download=True
    )

# Unshuffled loaders over the full splits, used here only for the sanity check
# (the subset loaders created later replace them for training/evaluation).
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Verify the scaling on the first batch only — and actually fail if it is wrong,
# instead of relying on someone reading the printed numbers.
data, _ = next(iter(train_loader))
min_pixel, max_pixel = data.min().item(), data.max().item()
print("Min pixel value:", min_pixel)
print("Max pixel value:", max_pixel)
# Small tolerance for float32 rounding in the interpolation weights.
assert 0.0 <= min_pixel and max_pixel <= 1.0 + 1e-6, "pixel values outside [0, 1]"

In [None]:
# # Compute mean and standard deviation
# transform = transforms.Compose([transforms.ToTensor()])
# train_dataset = datasets_function(
#     root="./data", train=True, transform=transform, download=True
# )
# train_loader = DataLoader(train_dataset, batch_size=64, shuffle=False)

# sum_tensor = torch.zeros(1)
# sum_squared_tensor = torch.zeros(1)
# count = 0

# for data, _ in train_loader:
#     sum_tensor += data.sum()
#     sum_squared_tensor += (data**2).sum()
#     count += data.numel()

# mean = sum_tensor / count
# std = torch.sqrt(sum_squared_tensor / count - mean**2)

# print("Mean:", mean.item())
# print("Standard Deviation:", std.item())

# # Create normalization transform
# normalize_transform = transforms.Compose(
#     [
#         transforms.ToTensor(),
#         transforms.Normalize(
#             mean=[mean.item()], std=[std.item()]
#         ),  # Normalize to mean 0 and std 1
#         transforms.Normalize(mean=[-1], std=[2]),  # Scale to mean 0.5 and std 0.5
#     ]
# )

# # Apply the Normalization Transform to the Dataset
# train_dataset = datasets_function(
#     root="./data", train=True, transform=normalize_transform, download=True
# )
# test_dataset = datasets_function(
#     root="./data", train=False, transform=normalize_transform, download=True
# )

# # Verify the transform
# train_loader = DataLoader(train_dataset, batch_size=64, shuffle=False)
# for data, _ in train_loader:
#     print("Transformed Mean:", data.mean().item())
#     print("Transformed Std:", data.std().item())
#     break  # Check only the first batch

In [None]:
proportion = 0.5  # fraction of each split that is kept

class_names = train_dataset.classes
num_classes = len(class_names)
print(f"{class_names = }")
print(f"{num_classes = }")

# Keep the leading `proportion` of each split (a prefix of the split, not a random sample).
num_train_samples, num_test_samples = (
    int(len(split) * proportion) for split in (train_dataset, test_dataset)
)
train_subset = Subset(train_dataset, range(num_train_samples))
test_subset = Subset(test_dataset, range(num_test_samples))

# Shape of a single image tensor (channel-first, as produced by ToTensor) and its
# total number of elements.
first_image, _first_label = train_subset[0]
train_data_size, test_data_size = len(train_subset), len(test_subset)
input_shape = first_image.shape
input_size = np.prod(input_shape)
print(f"{train_data_size = }")
print(f"{test_data_size = }")
print(f"{input_shape = }")
print(f"{input_size = }")

In [None]:
def display_instances(dataset, class_names, num_instances=3):
    """Show a grid with the first ``num_instances`` images found for each class.

    Rows are instance indices and columns are classes; the top row is titled
    with the class name. If ``dataset`` holds fewer than ``num_instances``
    examples of a class, the remaining cells of that column stay blank.

    Args:
        dataset: Iterable of ``(image, label)`` pairs; ``image`` is a
            channel-first tensor and ``label`` indexes into ``class_names``.
        class_names: Class names, indexed by label.
        num_instances: Number of examples to show per class.
    """
    # Collect up to `num_instances` images per class
    instances = {class_name: [] for class_name in class_names}
    for img, label in dataset:
        class_name = class_names[label]
        if len(instances[class_name]) < num_instances:
            instances[class_name].append(img)
        # Stop scanning as soon as every class is full
        if all(
            len(instances[class_name]) == num_instances for class_name in class_names
        ):
            break

    # squeeze=False keeps `axes` 2-D even for a single row or a single class,
    # so `axes[i, j]` indexing is always valid.
    fig, axes = plt.subplots(
        num_instances,
        len(class_names),
        figsize=(len(class_names) * 2, num_instances * 2),
        squeeze=False,
    )

    for j, class_name in enumerate(class_names):
        for i in range(num_instances):
            ax = axes[i, j]
            ax.axis("off")
            if i == 0:
                ax.set_title(class_name)
            if i >= len(instances[class_name]):
                continue  # not enough examples of this class in `dataset`
            img = instances[class_name][i]
            # Derive the colour mode from the image itself instead of a global shape.
            if img.shape[0] == 1:
                ax.imshow(img[0], cmap="gray")
            else:
                ax.imshow(img.permute(1, 2, 0))

    plt.tight_layout()
    plt.show()


display_instances(train_subset, class_names)

In [None]:
batch_size = 64  # Batch size for training

# Training batches are reshuffled every epoch; the test order stays fixed so
# evaluation output is comparable between runs.
train_loader = DataLoader(train_subset, batch_size, shuffle=True)
test_loader = DataLoader(test_subset, batch_size, shuffle=False)

In [None]:
# Define a simple neural network model
class SimpleNN(nn.Module):
    """Small three-stage CNN classifier.

    Each stage is Conv2d(3x3, padding=1) -> LeakyReLU -> MaxPool2d(2), so the
    spatial size is floor-divided by 8 before the flattened features reach the
    final Linear layer, which outputs ``num_classes`` logits.

    Args:
        num_classes: Number of output logits.
        input_shape: ``(channels, height, width)`` of one input image. Height
            and width may differ, but each must be at least 8 so a spatial
            position survives the three pooling stages.

    Raises:
        ValueError: If height or width is smaller than 8.
    """

    def __init__(self, num_classes: int, input_shape: tuple[int, int, int]):
        super().__init__()
        channels, height, width = (int(dim) for dim in input_shape)
        if height < 8 or width < 8:
            raise ValueError(
                f"input height and width must be >= 8, got {height}x{width}"
            )
        self.num_classes = num_classes
        self.input_channels = channels
        self.input_height = height
        self.input_width = width
        self.input_size = channels * height * width
        # Three floor-halvings compose to a single floor division by 8.
        self.linear_input_height = height // 8
        self.linear_input_width = width // 8
        self.layers = nn.ModuleList(
            [
                nn.Conv2d(channels, 32, kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
                nn.Flatten(1, -1),
                # Height and width are tracked separately so non-square inputs work.
                nn.Linear(
                    128 * self.linear_input_height * self.linear_input_width,
                    num_classes,
                ),
            ]
        )

    def forward(self, x):
        """Apply every layer in order; returns logits of shape ``(batch, num_classes)``."""
        for layer in self.layers:
            x = layer(x)
        return x


model = SimpleNN(num_classes=num_classes, input_shape=input_shape).to(device)

# Smoke test: push one random image through the network so shape mismatches
# surface here. No gradients are needed, and the logits shape is actually checked.
with torch.no_grad():
    dummy_logits = model(torch.rand(1, *input_shape, device=device))
assert dummy_logits.shape == (1, num_classes), dummy_logits.shape
print(model)

In [None]:
# Training loop
def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    num_epochs: int,
    criterion: nn.Module,
    optimizer: Optimizer,
) -> list[float]:
    """Train ``model`` in place and print the mean batch loss after each epoch.

    Batches are moved to the device that holds the model's parameters (CPU for a
    parameter-less model), so the function does not depend on a global device.

    Args:
        model: Network to train; it is switched to training mode.
        train_loader: Iterable of ``(images, labels)`` batches.
        num_epochs: Number of passes over ``train_loader``.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.

    Returns:
        The per-epoch mean of the batch losses. This is NaN for an epoch with no
        batches, where the original code raised ZeroDivisionError.
    """
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")
    epoch_width = len(str(num_epochs))  # right-align epoch numbers in the log
    epoch_losses = []

    model.train()  # Set the model to training mode
    for epoch in range(num_epochs):
        running_loss = 0.0
        num_batches = 0
        for images, labels in train_loader:
            images, labels = images.to(run_device), labels.to(run_device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        epoch_loss = running_loss / num_batches if num_batches else float("nan")
        epoch_losses.append(epoch_loss)
        print(
            f"Epoch [{epoch+1:>{epoch_width}}/{num_epochs}], Loss: {epoch_loss:2.4f}"
        )
    return epoch_losses


# Evaluation on the test set
def evaluate_model(model: nn.Module, test_loader: DataLoader, class_names: list[str]):
    """Predict every batch of ``test_loader`` and plot the evaluation report.

    Inputs are moved to the device holding the model's parameters (CPU for a
    parameter-less model), so no global ``device`` is needed. The predicted
    class is the argmax over the logits. The true and predicted label arrays
    are passed to ``plot_full_evaluation`` from the project's ``utils`` module.

    Args:
        model: Trained classifier; it is switched to evaluation mode.
        test_loader: Iterable of ``(images, labels)`` batches.
        class_names: Class names forwarded to ``plot_full_evaluation``.
    """
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")
    model.eval()  # Set the model to evaluation mode
    true_batches, pred_batches = [], []

    with torch.no_grad():  # Disable gradient computation
        for images, labels in test_loader:
            logits = model(images.to(run_device))
            pred_batches.append(logits.argmax(dim=1).cpu())
            # Labels are only compared on the CPU, so they never visit the GPU.
            true_batches.append(labels.cpu())

    y_true = torch.cat(true_batches).numpy() if true_batches else np.array([], dtype=np.int64)
    y_pred = torch.cat(pred_batches).numpy() if pred_batches else np.array([], dtype=np.int64)
    plot_full_evaluation(y_true, y_pred, class_names)


# # 6. Plot some predictions
# def plot_predictions(model, test_loader, classes):
#     model.eval()
#     images, labels = next(iter(test_loader))
#     images, labels = images.to(device), labels.to(device)

#     outputs = model(images)
#     _, preds = torch.max(outputs, 1)

#     # Plot the first 6 test images and their predictions
#     plt.figure(figsize=(10, 6))
#     for i in range(6):
#         plt.subplot(2, 3, i + 1)
#         plt.imshow(images[i].cpu().numpy().squeeze(), cmap="gray")
#         plt.title(f"True: {classes[labels[i]]}, Pred: {classes[preds[i]]}")
#         plt.axis("off")
#     plt.tight_layout()
#     plt.show()


# # Plot predictions
# plot_predictions(model, test_loader, test_dataset.classes)

In [None]:
# Hyperparameters
seed = 0  # Seeds weight init and the training-loader shuffle so runs are reproducible
num_epochs = 10  # Number of epochs
learning_rate = 0.001  # Learning rate for optimizer

# `device` is reused from the setup cell rather than redefined here, so this
# model lands on the same device as everything created earlier.
torch.manual_seed(seed)

# Initialize the model, loss function, and optimizer
model = SimpleNN(num_classes=num_classes, input_shape=input_shape).to(device)
criterion = nn.CrossEntropyLoss()  # For multi-class classification
optimizer = Adam(model.parameters(), lr=learning_rate)  # Often used optimizer

# Train the model
train_model(model, train_loader, num_epochs, criterion, optimizer)

# Evaluate the model
evaluate_model(model, test_loader, class_names)

In [None]:
# Target the last entry of model.layers (SimpleNN's Linear classifier) and
# build one image per class logit.
layers = [len(model.layers) - 1]
channels = list(range(num_classes))
channels_names = class_names
steps = 200  # passed to create_activation_image as `steps`
lr = 0.01  # passed to create_activation_image as `lr`
show_steps = True

# Optional transform handed to create_activation_image; None disables it.
# Candidate augmentations kept for experimentation:
#     transforms.Lambda(lambda x: x + 0.001 * (2 * torch.rand_like(x) - 1)),
#     transforms.RandomAffine(degrees=5, translate=(0.1, 0.1), scale=(0.9, 1.1)),
#     transforms.RandomCrop(20),
#     transforms.ElasticTransform(alpha=50.0),
activation_images_transform = None

# One slot per (layer, channel) pair, each shaped like a model input.
activation_images = np.zeros((len(layers), len(channels), *input_shape))
layer_bar = tqdm(layers, position=0, desc="layer", colour="green")
for layer_pos, layer in enumerate(layer_bar):
    channel_bar = tqdm(channels, position=1, desc="channel", colour="red")
    for channel_pos, channel in enumerate(channel_bar):
        # NOTE(review): the training transform has no Normalize step (pixels lie
        # in [0, 1]), so input_mean=0 / input_std=1 may not match the training
        # inputs. Confirm how create_activation_image uses these values.
        activation_images[layer_pos, channel_pos] = create_activation_image(
            model=model,
            layer=layer,
            channel=channel,
            input_mean=0,
            input_std=1,
            steps=steps,
            lr=lr,
            show_steps=show_steps,
            transform=activation_images_transform,
            input_shape=input_shape,
            progress_bar=False,
        )

In [None]:
# Display the images: one figure per layer, one panel per channel.
for i, layer in enumerate(layers):
    n_panels = len(channels)
    cols = int(np.ceil(np.sqrt(n_panels)))
    rows = int(np.ceil(n_panels / cols))
    # squeeze=False keeps `axes` 2-D even for a 1x1 or single-row grid.
    fig, axes = plt.subplots(rows, cols, figsize=(15, 15 * rows / cols), squeeze=False)
    fig.suptitle(f"Layer {layer}")
    for ax in axes.flat:
        ax.axis("off")  # also hides the unused trailing cells of the grid
    for ax, channel_name, image in zip(axes.flat, channels_names, activation_images[i]):
        image = image.transpose(1, 2, 0)  # channel-first -> channel-last for imshow
        if image.shape[-1] == 1:
            # Single channel: imshow autoscales the colormap to the data range.
            ax.imshow(image[..., 0], cmap="gray")
        else:
            # imshow clips float RGB data to [0, 1], so rescale each image first
            # to keep the optimised pattern visible whatever its value range.
            lo, hi = image.min(), image.max()
            ax.imshow((image - lo) / (hi - lo) if hi > lo else np.zeros_like(image))
        ax.set_title(channel_name)
    fig.tight_layout()
    plt.show()