In [None]:
import os
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import pickle
from tqdm import tqdm

# Where the per-class `<class name>.ndjson` files are read from, and where the
# processed arrays and label mapping are written to.
source_dir = "raw_data"
out_dir = "data"

# Classes to include. A class's position in this list becomes its integer label,
# so reordering the list changes every label.
classes = [
    "airplane", "anvil", "apple", "axe", "banana",
    "baseball", "bee", "bicycle", "book", "boomerang",
    "butterfly", "cactus", "clock", "cloud", "crown",
    "donut", "duck", "envelope", "fish", "flower",
    "hourglass", "light bulb", "lightning", "mountain", "scissors",
    "shark", "skull", "smiley face", "star",
]

# class name -> integer label in 0 .. len(classes) - 1
label_dict = dict(zip(classes, range(len(classes))))

In [None]:
# Load one newline-delimited JSON file per class from `source_dir` and stack them
# into a single frame `df`.
data_frames = []

for class_name in tqdm(classes):
    file_path = os.path.join(source_dir, f"{class_name}.ndjson")
    data_frames.append(pd.read_json(file_path, lines=True))

# ignore_index=True gives `df` a fresh 0..N-1 index; without it every per-class
# frame starts at 0 and the index labels repeat once per class.
df = pd.concat(data_frames, ignore_index=True)

# Fail fast, before the slow rasterisation step, if any file contains a word that
# is not one of `classes` (label_dict[...] would otherwise raise KeyError mid-loop).
unknown_words = set(df["word"]) - set(label_dict)
assert not unknown_words, f"unexpected words in source files: {sorted(unknown_words)}"

In [None]:
from image_processing import drawing_to_PIL, PIL_to_np

# Rasterise every drawing and record its integer label. The two lists stay
# aligned: entry i of X is the image for the drawing whose label is entry i of y.
X = []
y = []

# Walk the two needed columns directly instead of materialising a Series per row.
rows = zip(df["word"], df["drawing"])
for word, strokes in tqdm(rows, total=len(df)):
    label = label_dict[word]  # KeyError here means a word outside `classes`
    X.append(PIL_to_np(drawing_to_PIL(strokes)))
    y.append(label)

In [None]:
# Stack the per-drawing images into one array, insert a singleton axis at
# position 1 (the channel axis), and rescale by 1/255.
# NOTE(review): assumes each PIL_to_np image is a 2-D (H, W) array with pixel
# values in 0..255, giving X_all shape (N, 1, H, W) in [0, 1] — TODO confirm.
# New names (instead of rebinding X / y) keep this cell idempotent: re-running it
# no longer adds a second channel axis or divides by 255 twice.
# NOTE(review): `/ 255.0` produces float64; consider `.astype(np.float32)` if the
# consumer expects float32 or memory is tight.
X_all = np.expand_dims(np.asarray(X), axis=1) / 255.0
y_all = np.asarray(y)

assert len(X_all) == len(y_all), "image/label count mismatch"

# Hold out 25% (scikit-learn's default test_size, made explicit) and stratify on
# the label so every class keeps the same share in train and test.
X_train, X_test, y_train, y_test = train_test_split(
    X_all, y_all, test_size=0.25, stratify=y_all, random_state=42
)

In [None]:
# Persist the train/test splits and the name -> label mapping under `out_dir`.
# Create the directory first: np.save and open() raise FileNotFoundError when it
# does not exist yet.
os.makedirs(out_dir, exist_ok=True)

splits = {
    "X_train": X_train,
    "y_train": y_train,
    "X_test": X_test,
    "y_test": y_test,
}
for split_name, array in splits.items():
    # Explicit ".npy": np.save would append it anyway, so file names are unchanged.
    np.save(os.path.join(out_dir, f"{split_name}.npy"), array)

with open(os.path.join(out_dir, "label_dict.pkl"), "wb") as f:
    pickle.dump(label_dict, f)