<a href="https://colab.research.google.com/github/protagora/learnable-activation-function/blob/dev/batchnorm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader

In [14]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader

# Custom Batch Normalization
class CustomBatchNorm(nn.Module):
    def __init__(self, num_features, dim=2, eps=1e-5, momentum=0.1):
        super(CustomBatchNorm, self).__init__()
        self.eps = eps
        self.momentum = momentum
        self.dim = dim
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        if self.training:
            # Calculate mean and variance along the batch dimension
            if self.dim == 2:  # 2D for conv layers
                batch_mean = x.mean([0, 2, 3])
                batch_var = x.var([0, 2, 3], unbiased=False)
            else:  # 1D for fully connected layers
                batch_mean = x.mean(0)
                batch_var = x.var(0, unbiased=False)

            # Update running statistics
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * batch_mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * batch_var

            mean = batch_mean
            var = batch_var
        else:
            # Use running mean and variance during inference
            mean = self.running_mean
            var = self.running_var

        # Normalize
        x_normalized = (x - mean[None, :, None, None]) / torch.sqrt(var[None, :, None, None] + self.eps) if self.dim == 2 else (x - mean) / torch.sqrt(var + self.eps)

        # Scale and shift
        return self.gamma[None, :, None, None] * x_normalized + self.beta[None, :, None, None] if self.dim == 2 else self.gamma * x_normalized + self.beta

# Define a CNN model with the custom batch normalization
class CNNWithCustomBatchNorm(nn.Module):
    def __init__(self):
        super(CNNWithCustomBatchNorm, self).__init__()

        # Convolutional layers with custom batch normalization
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.bn1 = CustomBatchNorm(32, dim=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.bn2 = CustomBatchNorm(64, dim=2)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.bn3 = CustomBatchNorm(128, dim=2)

        # Pooling layer
        self.pool = nn.MaxPool2d(2, 2)  # Downsampling by 2

        # Placeholder for the fully connected layer; we'll determine in_features dynamically
        self.fc1 = None
        self.bn4 = None
        self.fc2 = nn.Linear(256, 10)  # CIFAR-10 has 10 classes

    def forward(self, x):
        # Convolutional layers with ReLU and custom batch normalization
        x = torch.relu(self.bn1(self.conv1(x)))
        x = self.pool(torch.relu(self.bn2(self.conv2(x))))
        x = self.pool(torch.relu(self.bn3(self.conv3(x))))

        # Flatten
        x = x.view(x.size(0), -1)  # Flatten

        # Initialize fully connected layer dynamically based on input size
        if self.fc1 is None:
            # Dynamically determine input size for fc1 based on current input dimensions
            self.fc1 = nn.Linear(x.size(1), 256).to(x.device)
            self.bn4 = CustomBatchNorm(256, dim=1).to(x.device)

        # Fully connected layers with ReLU and custom batch normalization
        x = torch.relu(self.bn4(self.fc1(x)))
        x = self.fc2(x)
        return x

# Check for CUDA
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load CIFAR-10 dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize to [-1, 1]
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Initialize the CNN model, loss function, and optimizer
model = CNNWithCustomBatchNorm().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training function
def train(model, train_loader, criterion, optimizer, num_epochs=10):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        correct_train = 0
        total_train = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # Calculate accuracy
            _, predicted = torch.max(outputs.data, 1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

        # Calculate training accuracy
        train_accuracy = 100 * correct_train / total_train
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}, Training Accuracy: {train_accuracy:.2f}%')

# Evaluation function
def evaluate(model, test_loader):
    model.eval()
    correct_test = 0
    total_test = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total_test += labels.size(0)
            correct_test += (predicted == labels).sum().item()

    test_accuracy = 100 * correct_test / total_test
    print(f'Accuracy of the model on the test set: {test_accuracy:.2f}%')

# Train and evaluate the CNN model
train(model, train_loader, criterion, optimizer, num_epochs=10)
evaluate(model, test_loader)


Files already downloaded and verified
Files already downloaded and verified
Epoch [1/10], Loss: 1.5326, Training Accuracy: 47.06%
Epoch [2/10], Loss: 1.1175, Training Accuracy: 61.57%
Epoch [3/10], Loss: 0.9486, Training Accuracy: 67.24%
Epoch [4/10], Loss: 0.8452, Training Accuracy: 71.08%
Epoch [5/10], Loss: 0.7642, Training Accuracy: 73.96%
Epoch [6/10], Loss: 0.7003, Training Accuracy: 76.49%
Epoch [7/10], Loss: 0.6440, Training Accuracy: 78.35%
Epoch [8/10], Loss: 0.5963, Training Accuracy: 79.94%
Epoch [9/10], Loss: 0.5526, Training Accuracy: 81.71%
Epoch [10/10], Loss: 0.5090, Training Accuracy: 83.47%
Accuracy of the model on the test set: 68.03%
