# CIFAR-10データセットの読込

In [None]:
import pickle
import numpy as np
import os
import random

from IPython.core.display import display

%matplotlib inline
from matplotlib import pyplot as plt

In [None]:
# Directory into which the CIFAR-10 archive (cifar-10-python.tar.gz) was extracted.
# ex: CIFAR_DIR = '/home/user/cifar10/cifar-10-batches-py/'
# NOTE: this is a placeholder; replace it with a real path before running the
# loading cells, which join batch file names onto it with os.path.join.
CIFAR_DIR = 'path to cifar-10-batches-py'

In [None]:
# Load an object that was saved with pickle.
def unpickle(file_name):
    """Open *file_name* in binary mode and return the object pickled in it.

    ``encoding='bytes'`` keeps 8-bit strings written by Python 2's pickle as
    ``bytes`` instead of decoding them, which is why the CIFAR-10 dictionaries
    loaded through this helper are keyed by ``b'data'``, ``b'labels'``, etc.
    """
    # NOTE(review): unpickling can execute arbitrary code; only call this on
    # trusted files such as the official CIFAR-10 download.
    with open(file_name, mode='rb') as stream:
        return pickle.load(stream, encoding='bytes')

In [None]:
# Read one CIFAR-10 batch file and append its (image data, label id) tuples to a list.
def GetCifar10(file_path, image_label_id_list):
    """Append every ``(image_data, label_id)`` pair of one CIFAR-10 batch.

    Args:
        file_path: Path to a pickled CIFAR-10 batch file (e.g. ``data_batch_1``).
        image_label_id_list: List extended in place with one tuple per sample:
            the i-th element of ``dataset[b'data']`` paired with the i-th
            element of ``dataset[b'labels']``.

    Raises:
        ValueError: If the batch holds a different number of images and labels.
    """
    dataset = unpickle(file_path)
    images = dataset[b'data']
    label_ids = dataset[b'labels']

    # zip() would silently drop the tail of the longer sequence; a length
    # mismatch means the batch file is malformed, so fail with a clear message.
    if len(images) != len(label_ids):
        raise ValueError('%s: %d images but %d labels'
                         % (file_path, len(images), len(label_ids)))

    image_label_id_list.extend(zip(images, label_ids))

In [None]:
# Load the training data: data_batch_1 ... data_batch_5 are read in order and
# all of their (image data, label id) tuples are appended to one list.
learn_image_label_id_list = []

for i in range(1, 6):  # batch numbers 1..5 (range end is exclusive)
    file_name = 'data_batch_%d' % i
    file_path = os.path.join(CIFAR_DIR, file_name)
    
    GetCifar10(file_path, learn_image_label_id_list)

In [None]:
# Load the evaluation (test) data, which lives in the single 'test_batch' file.
test_image_label_id_list = []

file_path = os.path.join(CIFAR_DIR, 'test_batch')
GetCifar10(file_path, test_image_label_id_list)

In [None]:
# Get the class label names (airplane, etc.).
def GetCifar10LabelName(CIFAR_DIR):
    """Return the ``label_names`` list stored in ``batches.meta`` under *CIFAR_DIR*.

    The names are returned exactly as unpickled; the display code decodes
    them with ``.decode()`` before printing. Position ``k`` in the list is the
    name used for label id ``k``.
    """
    # NOTE: the parameter deliberately keeps its original (global-shadowing)
    # name so keyword callers keep working.
    meta = unpickle(os.path.join(CIFAR_DIR, 'batches.meta'))
    return meta[b'label_names']

In [None]:
# Fetch the class label names (airplane, etc.); the list is indexed by label id.
label_name_list = GetCifar10LabelName(CIFAR_DIR)

In [None]:
# Convert one flat CIFAR-10 data row into an image array that matplotlib can show.
def ConvertToImage(data):
    """Return *data* rearranged from channel-first to channel-last order.

    The 3 * 32 * 32 = 3072 values are interpreted as (channel, row, column),
    i.e. all values of channel 0, then channel 1, then channel 2, and the
    result has shape ``(32, 32, 3)`` as expected by ``plt.imshow``.

    Args:
        data: Array-like holding exactly 3072 values (lists are accepted too).

    Returns:
        A ``(32, 32, 3)`` ndarray; a view of *data* when it is already an
        ndarray that can be reshaped without copying.

    Raises:
        ValueError: If *data* does not contain exactly 3072 values.
    """
    channel_first = np.asarray(data).reshape((3, 32, 32))
    # moveaxis is NumPy's recommended replacement for rollaxis; this moves the
    # channel axis from the front to the end: (3, 32, 32) -> (32, 32, 3).
    return np.moveaxis(channel_first, 0, -1)

In [None]:
# Pick a random (image data, label id) pair from the training data, show the
# image, then print its class name.
image_label_id = random.choice(learn_image_label_id_list)

image = ConvertToImage(image_label_id[0])  # flat data row -> displayable image array

plt.imshow(image)
plt.show()

image_label = label_name_list[image_label_id[1]]
# the stored label names need .decode() to become a printable str
print('Label: %s' % image_label.decode())