In [37]:
import os
import pickle
import random

import numpy as np
import PIL
import torch
import torchvision

from PIL import Image
from torch.utils.data import Subset
from torchvision import datasets, transforms

def load_txt(path: str) -> list:
    """Read a text file and return its lines with trailing newlines removed.

    Args:
        path: Path of the file to read.

    Returns:
        A list with one string per line of the file, each without its
        trailing ``'\\n'``.
    """
    # The context manager closes the handle deterministically; a bare open()
    # inside a comprehension leaves it open until garbage collection.
    with open(path) as handle:
        return [line.rstrip('\n') for line in handle]

# Names of the corruption types, one per line of corruptions.txt; CIFAR10C
# validates its `name` argument against this list.
corruptions = load_txt(path='./corruptions.txt')


class CIFAR10C(datasets.VisionDataset):
    """CIFAR-10-C samples for a single corruption type, read from ``.npy`` files.

    Reads ``<root>/<name>.npy`` as the image array and ``<root>/labels.npy``
    as the label array. Both are loaded fully into memory when the dataset
    is constructed.

    Args:
        root: Directory that contains the CIFAR-10-C ``.npy`` files.
        name: Corruption name; must appear in the notebook-level
            ``corruptions`` list.
        transform: Optional callable applied to each PIL image.
        target_transform: Optional callable applied to each label.
    """

    def __init__(self, root: str, name: str,
                 transform=None, target_transform=None):
        # Validity is decided by whatever the global `corruptions` list holds
        # at construction time.
        assert name in corruptions
        super().__init__(root, transform=transform,
                         target_transform=target_transform)
        images_file = os.path.join(root, name + '.npy')
        labels_file = os.path.join(root, 'labels.npy')
        self.data = np.load(images_file)
        self.targets = np.load(labels_file)

    def __getitem__(self, index):
        """Return ``(image, label)`` at ``index``, applying any configured transforms."""
        raw_pixels, label = self.data[index], self.targets[index]
        picture = Image.fromarray(raw_pixels)
        picture = picture if self.transform is None else self.transform(picture)
        label = label if self.target_transform is None else self.target_transform(label)
        return picture, label

    def __len__(self):
        """Number of images loaded from ``<name>.npy``."""
        return len(self.data)


def extract_subset(dataset, num_subset: int, random_subset: bool, seed: int = 0):
    """Return a ``Subset`` holding ``num_subset`` samples of ``dataset``.

    Args:
        dataset: Any indexable object supporting ``len()``.
        num_subset: Number of samples to keep.
        random_subset: If True, pick ``num_subset`` distinct indices at random
            (reproducible through ``seed``). Otherwise take the first
            ``num_subset`` indices in order.
        seed: Seed for the private RNG used when ``random_subset`` is True.
            The default 0 matches the value that was previously hard-coded.

    Returns:
        ``torch.utils.data.Subset`` wrapping ``dataset`` with the chosen indices.

    Raises:
        ValueError: If ``num_subset`` is negative or exceeds ``len(dataset)``.
    """
    size = len(dataset)
    # The random branch already failed here via random.sample; checking up
    # front also stops the sequential branch from building out-of-range
    # indices that would only fail later, on access.
    if not 0 <= num_subset <= size:
        raise ValueError(
            f'num_subset must be in [0, {size}] (len(dataset)), got {num_subset}')
    if random_subset:
        # A private Random(seed) draws exactly what random.seed(seed) followed
        # by random.sample(...) would, without resetting the global RNG.
        indices = random.Random(seed).sample(range(size), num_subset)
    else:
        indices = list(range(num_subset))
    return Subset(dataset, indices)



In [44]:
# Drop the first corruption listed in corruptions.txt from this export run.
# The reason for skipping it is not recorded here — TODO confirm.
# Rebuilding from the file instead of slicing the existing list keeps the cell
# idempotent: re-running it no longer discards one more entry each time.
# NOTE: CIFAR10C asserts `name in corruptions`, so after this cell the skipped
# entry is also rejected by CIFAR10C.
corruptions = load_txt('./corruptions.txt')[1:]
print(corruptions)

['shot_noise', 'speckle_noise', 'impulse_noise', 'defocus_blur', 'gaussian_blur', 'motion_blur', 'zoom_blur', 'snow', 'fog', 'brightness', 'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression', 'spatter', 'saturate', 'frost']


In [39]:
def cifar10c_loader(name, level, root='./data/cifar10c/CIFAR-10-C', out_dir='data/pickles'):
    """Export one severity level of a CIFAR-10-C corruption as pickled tensors.

    Loads ``<root>/<name>.npy`` and ``<root>/labels.npy`` through
    :class:`CIFAR10C` and keeps only the rows of the requested severity.
    Images are rescaled to float32 in ``[0, 1]`` and permuted to
    channel-first. The result ``[images, labels]`` is pickled to
    ``<out_dir>/cifar10c-<name>_<level>.pkl``.

    Assumes the standard CIFAR-10-C file layout: 50,000 rows, i.e. 5 severities
    x 10,000 images, with severity 1 first. Pixels are presumably uint8 in
    (N, H, W, C) order, since CIFAR10C feeds rows to ``Image.fromarray``.
    Only the /255 rescaling is applied; there is no mean/std normalization.

    Args:
        name: Corruption name; must pass CIFAR10C's membership check.
        level: Severity level, 1-5.
        root: Directory that holds the CIFAR-10-C ``.npy`` files.
        out_dir: Directory the pickle is written to (created if missing).

    Raises:
        ValueError: If ``level`` is outside 1-5, or the ``.npy`` file does not
            have the expected number of rows.
    """
    num_levels = 5
    per_level = 10000  # images per severity level in each CIFAR-10-C file
    if not 1 <= level <= num_levels:
        raise ValueError(f'level must be in 1..{num_levels}, got {level!r}')

    path = f'{out_dir}/cifar10c-{name}_{level}.pkl'
    ds = CIFAR10C(root=root, name=name)
    n_rows = len(ds.data)
    if n_rows != num_levels * per_level:
        raise ValueError(
            f'{name}.npy has {n_rows} rows; expected '
            f'{num_levels * per_level} ({num_levels} levels x {per_level}).')

    start = (level - 1) * per_level
    images = ds.data[start:start + per_level]
    labels = ds.targets[start:start + per_level]

    # Slice the numpy arrays BEFORE building tensors, for two reasons:
    # - torch.tensor copies only the selected rows, so float32 copies of all
    #   five levels are never materialized;
    # - the pickled tensor owns a storage of exactly this slice, whereas
    #   pickling a view of a larger tensor serializes its whole storage.
    x = (torch.tensor(images).float() / 255.).permute(0, 3, 1, 2).contiguous()
    y = torch.tensor(labels)

    os.makedirs(out_dir, exist_ok=True)
    with open(path, 'wb') as f:
        pickle.dump([x, y], f)
    print("Saved In: {c}".format(c=path))
    
    
    

In [45]:
# Export the severity-5 slice of every corruption kept by the filtering cell.
for corruption in corruptions:
    cifar10c_loader(corruption, level=5)

Saved In: data/pickles/cifar10c-shot_noise_5.pkl
Saved In: data/pickles/cifar10c-speckle_noise_5.pkl
Saved In: data/pickles/cifar10c-impulse_noise_5.pkl
Saved In: data/pickles/cifar10c-defocus_blur_5.pkl
Saved In: data/pickles/cifar10c-gaussian_blur_5.pkl
Saved In: data/pickles/cifar10c-motion_blur_5.pkl
Saved In: data/pickles/cifar10c-zoom_blur_5.pkl
Saved In: data/pickles/cifar10c-snow_5.pkl
Saved In: data/pickles/cifar10c-fog_5.pkl
Saved In: data/pickles/cifar10c-brightness_5.pkl
Saved In: data/pickles/cifar10c-contrast_5.pkl
Saved In: data/pickles/cifar10c-elastic_transform_5.pkl
Saved In: data/pickles/cifar10c-pixelate_5.pkl
Saved In: data/pickles/cifar10c-jpeg_compression_5.pkl
Saved In: data/pickles/cifar10c-spatter_5.pkl
Saved In: data/pickles/cifar10c-saturate_5.pkl
Saved In: data/pickles/cifar10c-frost_5.pkl
