In [1]:
# This Jupyter notebook preprocesses the raw CIFAR-10 batches into train, validation, and test splits

In [2]:
import pickle
import numpy as np
import pandas as pd
import re
import math
from utils import new_dir, read_pkl
from sklearn.model_selection import train_test_split
from collections import Counter

In [3]:
def convert_image(flatten_img):
    """
    Convert one flattened CIFAR-10 image into an (H, W, C) RGB array.

    Parameters
    ----------
    flatten_img : array-like holding 32*32*3 = 3072 values
        Channel-major pixel data as illustrated on the CIFAR-10 webpage:
        the first 1024 values are the red plane, the next 1024 the green
        plane and the last 1024 the blue plane, each plane stored row by row.
        Both a flat (3072,) row and a (1, 3072) slice are accepted.

    Returns
    -------
    numpy.ndarray
        A new C-contiguous array of shape (32, 32, 3), i.e. H:32, W:32, RGB:3.
        It does not share memory with the input.

    Raises
    ------
    ValueError
        If the input does not hold exactly 3072 values.
    """
    pixels = np.asarray(flatten_img)
    if pixels.size != 3 * 32 * 32:
        raise ValueError(
            f"expected 3072 values (3 channels of 32x32), got {pixels.size}"
        )

    # view the flat buffer as (channel, row, col): R plane, G plane, B plane
    planes = pixels.reshape(3, 32, 32)

    # move the channel axis last so the result is ordered RGB per pixel
    # https://stackoverflow.com/questions/46898979/how-to-check-the-channel-order-of-an-image
    # ascontiguousarray materialises the transposed view as an independent copy
    rgb_img = np.ascontiguousarray(planes.transpose(1, 2, 0))

    return rgb_img

In [4]:
# step 1: read all data
# empty containers for the images and labels collected by the loading cells
train_valid_imgs, train_valid_labels = [], []
test_imgs, test_labels = [], []

In [5]:
# train valid
# start from empty lists so that re-running this cell never appends duplicates
train_valid_imgs = []
train_valid_labels = []
for batch_idx in range(1, 6):
    # read in the batch data (files data_batch_1 ... data_batch_5)
    data_batch = read_pkl(f"./data/data_batch_{batch_idx}")
    batch_pixels = data_batch["data"]
    # row_idx is kept distinct from batch_idx so the batch counter is never shadowed
    for row_idx, label in enumerate(data_batch["labels"]):
        train_valid_imgs.append(convert_image(batch_pixels[row_idx, :]))
        train_valid_labels.append(label)

In [6]:
# sanity check: every loaded image must be paired with exactly one label
assert len(train_valid_labels) == len(train_valid_imgs), "image/label count mismatch"
len(train_valid_labels), len(train_valid_imgs)

(50000, 50000)

In [7]:
# test
# start from empty lists so that re-running this cell never appends duplicates
test_imgs = []
test_labels = []
test_batch = read_pkl("./data/test_batch")
test_pixels = test_batch["data"]
for row_idx, label in enumerate(test_batch["labels"]):
    test_imgs.append(convert_image(test_pixels[row_idx, :]))
    # store this image's own label (not the whole label list of the batch)
    test_labels.append(label)

In [8]:
# sanity check: every test image must be paired with exactly one label entry
assert len(test_labels) == len(test_imgs), "image/label count mismatch"
len(test_labels), len(test_imgs)

(10000, 10000)

In [9]:
# split train and validation set:
# 10% of the loaded training images are held out for validation,
# with a fixed random_state so the split is the same on every run
X_train, X_valid, y_train, y_valid = train_test_split(
    train_valid_imgs,
    train_valid_labels,
    test_size=0.1,
    random_state=1234,
)

In [10]:
# sanity check: images and labels must stay aligned within each split
assert len(X_train) == len(y_train), "train image/label count mismatch"
assert len(X_valid) == len(y_valid), "valid image/label count mismatch"
len(X_train), len(X_valid), len(y_train), len(y_valid)

(45000, 5000, 45000, 5000)

In [11]:
# zip them together: each set becomes a list of (image, label) pairs
train_set, valid_set, test_set = (
    list(zip(images, labels))
    for images, labels in (
        (X_train, y_train),
        (X_valid, y_valid),
        (test_imgs, test_labels),
    )
)

In [12]:
# write into files
new_dir("./processed_data")

# each split is pickled into its own file inside ./processed_data
for out_path, split in (
    ("./processed_data/train_set.pkl", train_set),
    ("./processed_data/valid_set.pkl", valid_set),
    ("./processed_data/test_set.pkl", test_set),
):
    with open(out_path, "wb") as fout:
        pickle.dump(split, fout)

./processed_data created!


In [13]:
# take a look at the label distribution of the training set:
# count how many training images carry each class label
Counter(y_train)

Counter({2: 4505,
         0: 4520,
         8: 4516,
         1: 4475,
         6: 4471,
         7: 4521,
         3: 4480,
         5: 4529,
         9: 4472,
         4: 4511})

In [14]:
# the per-class counts are almost the same, so the random split kept the training set balanced