In [1]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Image Mean & Std by Channel

In [None]:
def calculate_mean_std_explain():
    """
    Walk through the per-channel mean/std computation on a small random batch.

    Builds a random batch of shape (4, 3, 64, 64) with ``torch.randn``, then
    prints the batch shape and the mean and standard deviation of each channel,
    pooled over every pixel of every image in the batch.

    The standard deviation is computed as ``sqrt(E[x^2] - E[x]^2)`` over the
    pooled pixels (population standard deviation). Averaging the per-image
    standard deviations instead would ignore differences between the images'
    own means and would not give the spread of the channel as a whole.

    Returns:
        None. The results are only printed. No seed is set here, so the
        printed values differ between runs unless the caller seeds torch first.
    """
    # Simulate a batch of 4 random images, with 3 channels (RGB), and size 64x64 pixels
    batch_size = 4
    channels = 3
    height, width = 64, 64

    # Create a random tensor of shape (4, 3, 64, 64)
    images = torch.randn(batch_size, channels, height, width)

    # Step 1: Check the shape of the images
    print(f"Image Batch Shape: {images.shape}")

    # Step 2: Move the channel axis to the front and flatten everything else,
    # giving (channels, batch_size * H * W): one row of pooled pixels per channel.
    # reshape (not view) is required because transpose makes the tensor non-contiguous.
    pixels = images.transpose(0, 1).reshape(channels, -1)

    # Step 3: Per-channel statistics over the pooled pixels
    mean = pixels.mean(dim=1)
    # clamp guards against a tiny negative variance caused by rounding
    std = (pixels.pow(2).mean(dim=1) - mean.pow(2)).clamp(min=0).sqrt()

    print(f"Mean: {mean}")
    print(f"Standard Deviation: {std}")

In [None]:
def calculate_mean_std(dataset_path, batch_size=32):
    """
    Calculate the mean and standard deviation for each channel of the dataset.

    Images are loaded with ``datasets.ImageFolder`` and converted with
    ``transforms.ToTensor()``. The statistics are pooled over every pixel of
    every image: per-channel running sums of ``x`` and ``x**2`` are accumulated
    batch by batch and combined at the end as

        mean = sum(x) / N
        std  = sqrt(sum(x**2) / N - mean**2)      (population standard deviation)

    where ``N`` is the number of pixels per channel seen across the dataset.
    Averaging per-image standard deviations would ignore brightness differences
    between images and under-estimate the dataset's spread, so it is not used.

    Args:
        dataset_path (str): Path to the dataset directory.
        batch_size (int): Batch size for DataLoader. Default is 32.

    Returns:
        (mean, std): Tuple of float32 tensors, one value per channel, holding
        the mean and standard deviation of each channel.

    Raises:
        ValueError: If the DataLoader yields no images.
    """
    # Define a transformation to convert images to tensor
    transform = transforms.Compose([transforms.ToTensor()])

    # Load the dataset
    dataset = datasets.ImageFolder(dataset_path, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # Running per-channel totals; created from the first batch so the number
    # of channels is taken from the data rather than hard-coded.
    num_pixels = 0          # pixels seen so far, per channel
    channel_sum = None      # sum of pixel values, per channel
    channel_sq_sum = None   # sum of squared pixel values, per channel

    # Iterate over the dataset
    for images, _ in dataloader:
        # Use fresh names here so the `batch_size` argument is not overwritten
        # (the last batch may hold fewer images than `batch_size`).
        n_images, height, width = images.shape[0], images.shape[2], images.shape[3]
        num_pixels += n_images * height * width

        # Accumulate in float64: float32 running sums lose precision once
        # many millions of pixels have been added.
        images = images.to(torch.float64)
        batch_sum = images.sum(dim=(0, 2, 3))
        batch_sq_sum = images.pow(2).sum(dim=(0, 2, 3))

        if channel_sum is None:
            channel_sum = batch_sum
            channel_sq_sum = batch_sq_sum
        else:
            channel_sum += batch_sum
            channel_sq_sum += batch_sq_sum

    if channel_sum is None or num_pixels == 0:
        raise ValueError(f"No images found to compute statistics from: {dataset_path}")

    # Calculate final mean and std
    mean = channel_sum / num_pixels
    # clamp guards against a tiny negative variance caused by rounding
    variance = (channel_sq_sum / num_pixels - mean.pow(2)).clamp(min=0)
    std = variance.sqrt()

    # Hand back float32 tensors, the dtype callers received before.
    mean = mean.to(torch.float32)
    std = std.to(torch.float32)

    print(f"Mean: {mean}")
    print(f"Standard Deviation: {std}")

    return mean, std