In [1]:
import torch
import torchvision
import torchvision.transforms as transforms

In [2]:
# Preprocessing applied to every image: convert to a tensor, then normalize
# each of the three channels with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Training split: downloaded to ./data on first use, served in shuffled
# mini-batches of 4 by two worker processes.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
)

# Test split: same preprocessing, but iterated in a fixed order.
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=4,
    shuffle=False,
    num_workers=2,
)

# Short display names for the ten CIFAR-10 labels.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [1]:
import os
import pickle

import cv2
import mxnet as mx
import numpy as np

def extractImagesAndLabels(file):
    """Load one pickled CIFAR batch file and return its images and labels.

    Parameters
    ----------
    file : str
        Path to a pickled batch whose payload is a mapping with a ``'data'``
        entry (one flattened image per row) and a ``'labels'`` entry.

    Returns
    -------
    (imagearray, labelarray)
        Two ``mx.nd`` arrays: the images reshaped to ``(N, 3, 32, 32)`` and
        the ``N`` labels.

    Notes
    -----
    ``pickle.load`` can execute arbitrary code; only use this on batch files
    obtained from a trusted source.
    """
    # The context manager closes the handle even if unpickling fails.
    with open(file, 'rb') as f:
        batch = pickle.load(f, encoding='latin1')

    # Let numpy infer the number of images instead of assuming exactly 10000,
    # so batches of any size load correctly.
    images = np.reshape(batch['data'], (-1, 3, 32, 32))
    labels = batch['labels']

    imagearray = mx.nd.array(images)
    labelarray = mx.nd.array(labels)
    return imagearray, labelarray

def extractCategories(file):
    """Return the ``'label_names'`` entry of a pickled CIFAR metadata file.

    Parameters
    ----------
    file : str
        Path to the pickled metadata file (e.g. ``batches.meta``).

    Returns
    -------
    The object stored under ``'label_names'`` in the unpickled mapping.

    Notes
    -----
    ``pickle.load`` can execute arbitrary code; only use this on files
    obtained from a trusted source.
    """
    # The context manager closes the handle even if unpickling fails.
    with open(file, 'rb') as f:
        meta = pickle.load(f, encoding='latin1')
    return meta['label_names']

def saveCifarImage(array, path, file):
    """Write one image to ``path + file + ".png"`` using OpenCV.

    Parameters
    ----------
    array :
        Image exposing ``asnumpy()``, laid out as 3x32x32 in RGB channel order.
    path : str
        Prefix of the target file name; callers pass a folder ending in "/".
    file : str
        File name without extension.

    Returns
    -------
    The value returned by ``cv2.imwrite`` for the target file.
    """
    target = path+file+".png"

    # cv2.imwrite reports a missing directory by returning False rather than
    # raising, so make sure the destination folder exists before writing.
    folder = os.path.dirname(target)
    if folder:
        os.makedirs(folder, exist_ok=True)

    # array is 3x32x32. cv2 needs 32x32x3
    pixels = array.asnumpy().transpose(1,2,0)
    # array is RGB. cv2 needs BGR
    pixels = cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR)
    # save to PNG file
    return cv2.imwrite(target, pixels)

In [2]:
# Load the images and labels of training batch 2.
batch_file = "data/cifar-10-batches-py/data_batch_2"
imgarray, lblarray = extractImagesAndLabels(batch_file)

# Filled in a later cell: category name -> list of images.
classes = dict()

In [3]:
# Category names, in the order used to index them by numeric label below.
meta_file = "./data/cifar-10-batches-py/batches.meta"
categories = extractCategories(meta_file)
print(categories)

['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']


In [4]:
# Build the buckets here rather than relying on an earlier cell having left
# `classes` as an empty dict; re-running this cell always starts clean.
classes = {name: [] for name in categories}

# Copy all labels out of the NDArray once instead of calling asnumpy() for
# every single image inside the loop.
labels = lblarray.asnumpy()
for i, label in enumerate(labels):
    # Labels come back as floats; they index into `categories`.
    category_name = categories[int(label)]
    classes[category_name].append(imgarray[i])

In [5]:
# Sanity check: show the category names, the number of images collected for
# each one, and the grand total.
print(classes.keys())
counts = [len(images) for images in classes.values()]
for count in counts:
    print(count)
total = sum(counts)
print(total)

dict_keys(['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'])
984
1007
1010
995
1010
988
1008
1026
987
985
10000


In [7]:
# Export every image of the batch as <path>/<category>/<index>.png.
path = "cifar10-classes-batch2/"
for name in classes.keys():
    folder = path+name+"/"
    # The image writer does not create missing folders, so on a fresh
    # checkout nothing would be saved; create each category folder up front.
    os.makedirs(folder, exist_ok=True)
    print(folder)
    for i, image in enumerate(classes[name]):
        # saveCifarImage signals failure only through its return value.
        if not saveCifarImage(image, folder, str(i)):
            raise IOError("could not write {}{}.png".format(folder, i))

cifar10-classes-batch2/airplane/
cifar10-classes-batch2/automobile/
cifar10-classes-batch2/bird/
cifar10-classes-batch2/cat/
cifar10-classes-batch2/deer/
cifar10-classes-batch2/dog/
cifar10-classes-batch2/frog/
cifar10-classes-batch2/horse/
cifar10-classes-batch2/ship/
cifar10-classes-batch2/truck/
