# Cifar10 数据集

> 官网：http://www.cs.toronto.edu/~kriz/cifar.html

该数据集共有 60000 张 32×32 的彩色图像，分为 10 个类，每类 6000 张。
其中 50000 张用于训练，另外 10000 张用于测试。

<img src="images/img-2021-07-25-10-38-37.png" width="30%">

一定要看图：只有这样才能知道哪些图片识别错了、原因是什么，从而进一步优化模型的性能。

In [1]:
import os
import cv2
import numpy as np
import glob
from matplotlib import pyplot as plt

%matplotlib inline

In [2]:
# Cifar10 官网辅助函数
def unpickle(file):
    """Load a pickled batch file and return the object stored in it.

    Args:
        file: Path of the pickle file to read.

    Returns:
        The unpickled object. The file is decoded with ``encoding='bytes'``,
        so 8-bit strings pickled by Python 2 come back as ``bytes``.
    """
    import pickle

    # NOTE: unpickling can execute arbitrary code; only load files that come
    # from a trusted source.
    with open(file, 'rb') as fo:
        # Avoid naming the result `dict`, which would shadow the builtin.
        batch = pickle.load(fo, encoding='bytes')
    return batch

In [4]:
# Class names, indexed by the integer label.
label_name = [
    "airplane",    # 0
    "automobile",  # 1
    "bird",        # 2
    "cat",         # 3
    "deer",        # 4
    "dog",         # 5
    "frog",        # 6
    "horse",       # 7
    "ship",        # 8
    "truck",       # 9
]

In [5]:
# glob returns matches in a file-system dependent order; sort them so the
# batch files are always processed in the same order.
train_list = sorted(glob.glob("dataset/cifar-10/data_batch_*"))
test_list = sorted(glob.glob("dataset/cifar-10/test_batch"))

In [6]:
train_list

['dataset/cifar-10\\data_batch_1',
 'dataset/cifar-10\\data_batch_2',
 'dataset/cifar-10\\data_batch_3',
 'dataset/cifar-10\\data_batch_4',
 'dataset/cifar-10\\data_batch_5']

In [8]:
test_list

['dataset/cifar-10/test_batch']

In [7]:
# Export every image of the training batches as a file under
# <sava_path>/<class name>/<original file name>.
sava_path = "dataset/cifar-10/train"

# Create all class folders once, up front. makedirs also creates a missing
# parent folder and does not fail when the folder already exists.
for class_name in label_name:
    os.makedirs(f"{sava_path}/{class_name}", exist_ok=True)

for batch_path in train_list:
    l_dict = unpickle(batch_path)
    print(l_dict.keys())

    # Each entry of b'data' is a flat vector holding one image; labels and
    # file names are stored at the same position in their own lists.
    for im_data, im_label, im_name in zip(
            l_dict[b'data'], l_dict[b'labels'], l_dict[b'filenames']):
        im_label_name = label_name[im_label]  # class name for this label

        # The vector is laid out channel-first: reshape it to (3, 32, 32),
        # then move the channel axis to the end to get (32, 32, 3).
        im_data = np.reshape(im_data, [3, 32, 32])
        im_data = np.transpose(im_data, [1, 2, 0])

        # The CIFAR-10 channels are ordered R, G, B, while cv2.imwrite
        # expects B, G, R; convert so the saved colors are not swapped.
        im_data = cv2.cvtColor(im_data, cv2.COLOR_RGB2BGR)

        # Write the image under its class folder using its original name.
        cv2.imwrite(f"{sava_path}/{im_label_name}/{im_name.decode('utf-8')}",
                    im_data)

dict_keys([b'batch_label', b'labels', b'data', b'filenames'])
dict_keys([b'batch_label', b'labels', b'data', b'filenames'])
dict_keys([b'batch_label', b'labels', b'data', b'filenames'])
dict_keys([b'batch_label', b'labels', b'data', b'filenames'])
dict_keys([b'batch_label', b'labels', b'data', b'filenames'])


In [9]:
# Export every image of the test batch as a file under
# <sava_path>/<class name>/<original file name>.
sava_path = "dataset/cifar-10/test"

# Create all class folders once, up front. makedirs also creates a missing
# parent folder and does not fail when the folder already exists.
for class_name in label_name:
    os.makedirs(f"{sava_path}/{class_name}", exist_ok=True)

for batch_path in test_list:
    l_dict = unpickle(batch_path)
    print(l_dict.keys())

    # Each entry of b'data' is a flat vector holding one image; labels and
    # file names are stored at the same position in their own lists.
    for im_data, im_label, im_name in zip(
            l_dict[b'data'], l_dict[b'labels'], l_dict[b'filenames']):
        im_label_name = label_name[im_label]  # class name for this label

        # The vector is laid out channel-first: reshape it to (3, 32, 32),
        # then move the channel axis to the end to get (32, 32, 3).
        im_data = np.reshape(im_data, [3, 32, 32])
        im_data = np.transpose(im_data, [1, 2, 0])

        # The CIFAR-10 channels are ordered R, G, B, while cv2.imwrite
        # expects B, G, R; convert so the saved colors are not swapped.
        im_data = cv2.cvtColor(im_data, cv2.COLOR_RGB2BGR)

        # No cv2.imshow / cv2.waitKey(0) preview here: waitKey(0) blocks
        # until a key is pressed, which would stall the export on every image.

        # Write the image under its class folder using its original name.
        cv2.imwrite(f"{sava_path}/{im_label_name}/{im_name.decode('utf-8')}",
                    im_data)

dict_keys([b'batch_label', b'labels', b'data', b'filenames'])
