In [None]:
import torch
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10, CIFAR100, SVHN
import random
from torch.utils.data import DataLoader, Subset

In [None]:
# Use the CUDA device when torch reports it as available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def get_mean_std(dataset=''):
    """Return the hard-coded normalization statistics for a dataset.

    Args:
        dataset: Dataset name. Supported values are 'cifar10' and 'cifar100'.

    Returns:
        A ``(mean, std)`` pair, each a 3-tuple of floats.

    Raises:
        ValueError: If ``dataset`` is not a supported name. Previously this
            case fell through both branches and failed at ``return`` with an
            ``UnboundLocalError`` because ``mean``/``std`` were never assigned.
    """
    stats = {
        'cifar10': ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        'cifar100': ((0.5071, 0.4865, 0.4409), (0.2009, 0.1984, 0.2023)),
    }

    # Fail loudly with the list of valid names instead of an unbound-local error.
    if dataset not in stats:
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of {sorted(stats)}"
        )

    mean, std = stats[dataset]
    return mean, std

In [None]:
def get_imageloader(dataset='', batch_size=0, mean=(0, 0, 0), std=(0, 0, 0)):
    """Build train/test DataLoaders for a torchvision image dataset.

    Images are converted with ``ToTensor`` and then normalized with the given
    ``mean``/``std``. Data is downloaded to ``./data/<dataset>`` if missing.
    All loaders use ``shuffle=False`` and ``num_workers=2``.

    Args:
        dataset: One of 'cifar10', 'cifar100' or 'svhn'.
        batch_size: Batch size passed to every DataLoader.
        mean: Mean passed to ``transforms.Normalize``.
        std: Standard deviation passed to ``transforms.Normalize``.

    NOTE(review): the defaults (``batch_size=0``, all-zero ``std``) look like
    placeholders rather than usable values; callers presumably always pass
    real ones -- confirm against call sites.

    Returns:
        ``(trainloader, testloader)``. For 'svhn' only the test split is
        loaded, so ``trainloader`` is ``None`` and ``testloader`` iterates
        over a fixed subset of 10000 test images.

    Raises:
        ValueError: If ``dataset`` is not a supported name. The check happens
            before anything is downloaded; previously an unknown name failed
            only at ``return`` with an ``UnboundLocalError``.
    """
    supported = ('cifar10', 'cifar100', 'svhn')
    if dataset not in supported:
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of {list(supported)}"
        )

    datapath = './data/' + dataset

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])

    def _make_loader(data):
        # Single place for the loader settings shared by every split.
        return DataLoader(data, batch_size=batch_size, shuffle=False, num_workers=2)

    trainloader = None

    if dataset == 'svhn':
        testset = SVHN(root=datapath, split='test', download=True, transform=transform)

        num_samples = 10000

        print(f'num_samples: {num_samples}')

        # A private generator seeded with 42 draws the same indices as
        # ``random.seed(42); random.sample(...)`` but leaves the global
        # ``random`` state untouched for the rest of the notebook.
        rng = random.Random(42)
        indices = rng.sample(range(len(testset)), num_samples)
        testset = Subset(testset, indices)
    else:
        # The two CIFAR variants differ only in the dataset class.
        dataset_cls = CIFAR10 if dataset == 'cifar10' else CIFAR100
        trainset = dataset_cls(root=datapath, train=True, download=True, transform=transform)
        testset = dataset_cls(root=datapath, train=False, download=True, transform=transform)
        trainloader = _make_loader(trainset)

    testloader = _make_loader(testset)
    return trainloader, testloader