## Calculate `MEAN` and `STD` for WebFace and VGGFace2 Datasets

In [1]:
from tqdm import tqdm

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.ToTensor()  # Converts image to [C, H, W] with values in [0, 1]
])

def calculate_statistics(dataset_path, batch_size=256):
    dataset = datasets.ImageFolder(dataset_path, transform=transform)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    mean = 0.0
    std = 0.0
    num_samples = 0

    for images, _ in tqdm(data_loader):
        batch_samples = images.size(0)  # Batch size
        images = images.view(batch_samples, images.size(1), -1)  # Flatten HxW
        mean += images.mean(2).sum(0)  # Mean across (H, W) for each channel
        std += images.std(2).sum(0)  # Std across (H, W) for each channel
        num_samples += batch_samples

    mean /= num_samples
    std /= num_samples

    print("Mean:", mean)
    print("Standard Deviation:", std)

### `mean` and `std` for WebFace dataset

In [2]:
webface = "data/train/webface_112x112"
calculate_statistics(webface, batch_size=512)

100%|██████████| 959/959 [05:51<00:00,  2.73it/s]

Mean: tensor([0.5203, 0.4045, 0.3465])
Standard Deviation: tensor([0.2417, 0.2076, 0.1948])





### `mean` and `std` for VggFace2 dataset

In [3]:
vggface = "data/train/vggface2_train_112x112"
calculate_statistics(vggface, batch_size=512)

100%|██████████| 6129/6129 [54:21<00:00,  1.88it/s]  


Mean: tensor([0.5334, 0.4158, 0.3601])
Standard Deviation: tensor([0.2467, 0.2135, 0.2010])
