<a href="https://colab.research.google.com/github/protagora/learnable-activation-function/blob/dev/batchnorm_using_cupy_histogram.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader

import cupy as xp  # Import CuPy as xp

# try:
#     import cupy as xp  # Import CuPy as xp
# except ImportError:
#     import numpy as xp  # Fall back to NumPy if CuPy isn't available

# Custom Batch Normalization
class CustomHistogramBatchNorm(nn.Module):
    """Batch normalisation driven by histogram estimates of mean and variance.

    During training every feature/channel is binned into ``bins`` equal-width
    bins spanning that channel's own ``[min, max]``.  The batch mean and
    variance are then estimated from the bin centres weighted by the fraction
    of samples in each bin.  These estimates are treated as constants (no
    gradient flows through them), update the running statistics, and are used
    to normalise the batch.  In eval mode the running statistics are used.

    Args:
        num_features: Number of features/channels (size of dimension 1 of the
            input).
        bins: Number of equal-width histogram bins per channel (must be >= 1).
        dim: ``2`` for ``(N, C, H, W)`` inputs; any other value treats the
            input as ``(N, C)``.
        eps: Constant added to the variance before the square root.
        momentum: Weight given to the current batch when updating the running
            statistics.

    Raises:
        ValueError: If ``bins`` is smaller than 1.
    """

    def __init__(self, num_features, bins=100, dim=2, eps=1e-5, momentum=0.1):
        super().__init__()
        if bins < 1:
            raise ValueError(f"bins must be a positive integer, got {bins}")
        self.eps = eps
        self.momentum = momentum
        self.dim = dim
        self.bins = bins  # Number of bins for the histogram

        # Learnable parameters for scaling and shifting
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))

        # Running statistics for inference
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    @torch.no_grad()
    def _histogram_moments(self, x_flat):
        """Estimate per-channel mean and variance of ``x_flat`` (C, N) from histograms.

        Returns two tensors of shape ``(C,)`` on the same device as ``x_flat``.
        The computation stays in torch, so it runs on whatever device the
        input lives on and never touches the autograd graph.
        """
        values = x_flat.detach()
        if values.dtype in (torch.float16, torch.bfloat16):
            # Bin counts and moment sums need more precision than half floats.
            values = values.float()

        num_channels, num_samples = values.shape
        lo = values.amin(dim=1, keepdim=True)
        hi = values.amax(dim=1, keepdim=True)
        bin_width = (hi - lo) / self.bins

        # A constant channel has zero width; divide by 1 there so every sample
        # lands in bin 0 and the estimate degenerates to (value, 0).
        safe_width = torch.where(bin_width > 0, bin_width, torch.ones_like(bin_width))
        bin_idx = ((values - lo) / safe_width).long().clamp_(0, self.bins - 1)

        counts = torch.zeros(num_channels, self.bins, dtype=values.dtype, device=values.device)
        counts.scatter_add_(1, bin_idx, torch.ones_like(values))

        # Weight each bin by its share of samples.  (Averaging a *density*
        # histogram over bins instead would scale the result by 1 / (max - min).)
        weights = counts / num_samples
        offsets = torch.arange(self.bins, dtype=values.dtype, device=values.device) + 0.5
        centers = lo + offsets * bin_width  # (C, bins)

        mean = (weights * centers).sum(dim=1)
        var = (weights * (centers - mean.unsqueeze(1)) ** 2).sum(dim=1)
        return mean, var

    def forward(self, x):
        """Normalise ``x`` per channel, then apply the learnable scale and shift."""
        is_conv = self.dim == 2

        if self.training:
            # Bring the channel axis to the front and merge everything else.
            if is_conv:  # 2D for conv layers
                x_flat = x.permute(1, 0, 2, 3).reshape(x.size(1), -1)
            else:  # 1D for fully connected layers
                x_flat = x.t()

            mean, var = self._histogram_moments(x_flat)

            # Update running statistics in place so the registered buffers
            # keep their identity and dtype.
            with torch.no_grad():
                self.running_mean.mul_(1 - self.momentum).add_(
                    mean.to(self.running_mean.dtype), alpha=self.momentum
                )
                self.running_var.mul_(1 - self.momentum).add_(
                    var.to(self.running_var.dtype), alpha=self.momentum
                )
        else:
            # Use running mean and variance during inference
            mean, var = self.running_mean, self.running_var

        # Shape the per-channel vectors so they broadcast against x.
        stat_shape = (1, -1, 1, 1) if is_conv else (1, -1)
        mean = mean.to(x.dtype).reshape(stat_shape)
        var = var.to(x.dtype).reshape(stat_shape)

        # Normalize
        x_normalized = (x - mean) / torch.sqrt(var + self.eps)

        # Scale and shift
        return self.gamma.reshape(stat_shape) * x_normalized + self.beta.reshape(stat_shape)

# Define a CNN model with the custom batch normalization
class CNNWithCustomBatchNorm(nn.Module):
    """Three-block CNN classifier using ``CustomHistogramBatchNorm`` throughout.

    All layers, including the first fully connected layer and its
    normalisation, are created in ``__init__``.  Creating them lazily inside
    ``forward`` would leave their parameters out of any optimiser built from
    ``model.parameters()`` before the first forward pass, so they would never
    be trained.

    Args:
        in_channels: Number of channels of the input images.
        num_classes: Number of output logits.
        input_size: Spatial size of the input images, either an ``int`` for
            square images or a ``(height, width)`` pair.  It is used to size
            the first fully connected layer.
    """

    def __init__(self, in_channels=3, num_classes=10, input_size=32):
        super().__init__()

        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size

        # Convolutional layers with custom batch normalization
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=3, padding=1)
        self.bn1 = CustomHistogramBatchNorm(32, dim=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.bn2 = CustomHistogramBatchNorm(64, dim=2)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.bn3 = CustomHistogramBatchNorm(128, dim=2)

        # Pooling layer
        self.pool = nn.MaxPool2d(2, 2)  # Downsampling by 2

        # The padded 3x3 convolutions keep the spatial size; only the two
        # pooling steps in forward() shrink it, each by a floor division by 2.
        flat_features = 128 * (height // 4) * (width // 4)

        self.fc1 = nn.Linear(flat_features, 256)
        self.bn4 = CustomHistogramBatchNorm(256, dim=1)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for image batch ``x``."""
        # Convolutional layers with ReLU and custom batch normalization
        x = torch.relu(self.bn1(self.conv1(x)))
        x = self.pool(torch.relu(self.bn2(self.conv2(x))))
        x = self.pool(torch.relu(self.bn3(self.conv3(x))))

        # Flatten everything except the batch dimension
        x = torch.flatten(x, 1)

        # Fully connected layers with ReLU and custom batch normalization
        x = torch.relu(self.bn4(self.fc1(x)))
        return self.fc2(x)

# Pick the GPU when PyTorch can see one, otherwise stay on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# CIFAR-10 preprocessing: tensor conversion, then per-channel (x - 0.5) / 0.5
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize to [-1, 1]
    ]
)

# Training and test splits (downloaded into ./data when missing)
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
test_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Shuffled large batches for training, fixed-order small batches for testing
train_loader = DataLoader(dataset=train_dataset, batch_size=512, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

# Model, loss and optimizer.
# NOTE: the optimizer only tracks the parameters that exist at this point.
model = CNNWithCustomBatchNorm()
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training function
def train(model, train_loader, criterion, optimizer, num_epochs=10):
    """Train ``model`` and print the mean loss and accuracy after every epoch.

    Batches are moved to the device of the model's first parameter (CPU when
    the model has no parameters), so the function does not depend on any
    module-level ``device`` variable.

    Args:
        model: Module mapping a batch of inputs to class logits.
        train_loader: Iterable yielding ``(images, labels)`` batches; it is
            iterated once per epoch and does not need to support ``len()``.
        criterion: Callable returning a scalar loss from ``(outputs, labels)``.
        optimizer: Optimizer over the parameters to update.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        None. An epoch that yields no batches is reported with ``nan`` loss
        and accuracy instead of raising ``ZeroDivisionError``.
    """
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        correct_train = 0
        total_train = 0
        num_batches = 0
        for images, labels in train_loader:
            images, labels = images.to(target_device), labels.to(target_device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

            # Calculate accuracy (detached: the prediction is bookkeeping only)
            predicted = outputs.detach().argmax(dim=1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

        # Per-epoch summary; guard against a loader that produced nothing
        avg_loss = running_loss / num_batches if num_batches else float("nan")
        train_accuracy = 100 * correct_train / total_train if total_train else float("nan")
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Training Accuracy: {train_accuracy:.2f}%')

# Evaluation function
def evaluate(model, test_loader):
    """Print the classification accuracy of ``model`` over ``test_loader``.

    The model is switched to eval mode (and left there) and gradients are
    disabled for the whole pass.  Batches are moved to the device of the
    model's first parameter (CPU when the model has no parameters), so the
    function does not depend on any module-level ``device`` variable.

    Args:
        model: Module mapping a batch of inputs to class logits.
        test_loader: Iterable yielding ``(images, labels)`` batches.

    Returns:
        None. An empty loader is reported as ``nan`` accuracy instead of
        raising ``ZeroDivisionError``.
    """
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    correct_test = 0
    total_test = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(target_device), labels.to(target_device)
            predicted = model(images).argmax(dim=1)
            total_test += labels.size(0)
            correct_test += (predicted == labels).sum().item()

    test_accuracy = 100 * correct_test / total_test if total_test else float("nan")
    print(f'Accuracy of the model on the test set: {test_accuracy:.2f}%')

# Fit the CNN for ten epochs, then report its accuracy on the held-out test split
train(
    model=model,
    train_loader=train_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=10,
)
evaluate(model=model, test_loader=test_loader)



Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:04<00:00, 38.3MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Epoch [1/10], Loss: 1.6845, Training Accuracy: 40.40%
Epoch [2/10], Loss: 1.3250, Training Accuracy: 52.87%
Epoch [3/10], Loss: 1.1815, Training Accuracy: 58.35%
Epoch [4/10], Loss: 1.0842, Training Accuracy: 61.69%
Epoch [5/10], Loss: 1.0066, Training Accuracy: 64.45%
Epoch [6/10], Loss: 0.9632, Training Accuracy: 66.09%
Epoch [7/10], Loss: 0.9108, Training Accuracy: 68.06%
Epoch [8/10], Loss: 0.8729, Training Accuracy: 69.29%
Epoch [9/10], Loss: 0.8378, Training Accuracy: 70.77%
Epoch [10/10], Loss: 0.8120, Training Accuracy: 71.52%
Accuracy of the model on the test set: 62.92%


In [None]:
import torch

try:
    import cupy as xp  # Import CuPy as xp
except ImportError:
    import numpy as xp  # Fall back to NumPy if CuPy isn't available

# Small random batch shaped (N, C, H, W); despite its name it is not flattened yet
x_flat = torch.rand(1, 4, 2, 5)

bins = 10


def compute_histogram(values, bins=100, range=None, density=False):
    """Return ``(hist, bin_edges)`` from ``xp.histogram`` for ``values``.

    Thin wrapper so the array backend (CuPy or NumPy, whichever was imported
    as ``xp``) is chosen in one place.
    """
    return xp.histogram(values, bins=bins, range=range, density=density)


# Reference statistics per channel (axis 1). The population variance is used
# because the histogram estimate below is also a population variance.
direct_mean = torch.mean(x_flat, dim=(0, 2, 3))
direct_var = torch.var(x_flat, dim=(0, 2, 3), unbiased=False)

# Histogram-based mean and variance.
# Always convert through xp so the array type matches whichever backend was
# imported, independently of what torch reports about CUDA.
x_array = xp.asarray(x_flat.cpu().numpy())

# Move the channel axis (axis=1) to the front
reordered_arr = xp.moveaxis(x_array, 1, 0)  # Shape is now (#1, #0, #2, #3)

# Reshape to merge all but the first axis
merged_arr = reordered_arr.reshape(x_array.shape[1], -1)

# Initialize arrays to store the hist_mean and hist_var
hist_mean = xp.zeros(merged_arr.shape[0])
hist_var = xp.zeros(merged_arr.shape[0])

# Compute the histogram for each slice along axis #1 not #0
for i in range(merged_arr.shape[0]):
    channel_values = merged_arr[i]
    value_range = (float(channel_values.min()), float(channel_values.max()))
    hist, bin_edges = compute_histogram(channel_values, bins=bins, range=value_range, density=False)

    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

    # Weight every bin centre by the fraction of samples in that bin.
    # Averaging a density histogram over the bins would instead scale the
    # estimate by 1 / (max - min).
    weights = hist / hist.sum()
    hist_mean[i] = (weights * bin_centers).sum()
    hist_var[i] = (weights * (bin_centers - hist_mean[i]) ** 2).sum()


print("Direct Mean:", direct_mean, "Histogram Mean:", hist_mean)
print("Direct Variance:", direct_var, "Histogram Variance:", hist_var)

Direct Mean: tensor([0.6813, 0.3254, 0.3623, 0.4885]) Histogram Mean: [0.68887562 0.33697416 0.37751219 0.63782971]
Direct Variance: tensor([0.0944, 0.0847, 0.0823, 0.0843]) Histogram Variance: [0.07751192 0.07046054 0.06593457 0.12136811]
