# Setup

In [1]:
import os
import pandas as pd

In [2]:
# root directory of the image dataset:
DATA_DIR = "/proj/systewar/datasets/IMAC/image_dataset"
# names of the label subdirectories expected directly under DATA_DIR
# (each subdirectory name doubles as the label of the images it contains):
SUBDIR_NAMES = [
    "excitement",
    "anger",
    "fear",
    "amusement",
    "awe",
    "contentment",
    "disgust",
    "sadness",
]

In [3]:
# script options:
N_class_val = 50     # number of samples per class in validation set
N_class_test = 200     # number of samples per class in test set
# output CSV file of each subset, keyed by subset name.
# The paths are derived from DATA_DIR so that the dataset root is defined in exactly one
# place; the resulting strings are the same as the previously hard-coded ones.
# NOTE: the key order (train, val, test) matters - the sanity checks further down pair
# these keys positionally with [metadata_train, metadata_val, metadata_test].
metadata_files_split = {
    split_name: os.path.join(DATA_DIR, "metadata_{}.csv".format(split_name))
    for split_name in ("train", "val", "test")
}

# Get Image Files

In [4]:
# get dataset label directories:
# Walk the expected names in SUBDIR_NAMES order instead of filtering os.listdir(), whose
# result order is arbitrary: the list comparison below then only fails when an expected
# label directory is actually missing, not when the filesystem lists entries differently.
subdir_names = [name for name in SUBDIR_NAMES if os.path.isdir(os.path.join(DATA_DIR, name))]
assert subdir_names == SUBDIR_NAMES, "Error with listing dataset subdirectories."

In [5]:
# get all image file paths:
# Three parallel lists (one entry per image) that become the columns of the metadata frame.
subdirs = []
image_file_names = []
emotion_labels = []
for label in subdir_names:
    # get file_names:
    subdir_path = os.path.join(DATA_DIR, label)
    # os.listdir() returns entries in arbitrary order; sort them so that the row order of
    # the dataframe - and with it the seeded random sampling of the splits further down -
    # is reproducible across machines and filesystems.
    file_names = sorted(
        name for name in os.listdir(subdir_path)
        if os.path.isfile(os.path.join(subdir_path, name))
    )
    n_images = len(file_names)
    # save metadata (the subdirectory name is also used as the label):
    subdirs += n_images * [label]
    image_file_names += file_names
    emotion_labels += n_images * [label]

# create dataframe:
metadata = pd.DataFrame(
    data={
        "subdir_name": subdirs,
        "file_name": image_file_names,
        "label": emotion_labels
    }
)

print("Total size of dataset: {}".format(metadata.shape[0]))
print()
# DataFrame.info() prints its summary itself and returns None, so it is not wrapped in
# print() (which would emit an extra "None" line).
metadata.info()

Total size of dataset: 21829

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 21829 entries, 0 to 21828
Data columns (total 3 columns):
 #   Column       Non-Null Count  Dtype 
---  ------       --------------  ----- 
 0   subdir_name  21829 non-null  object
 1   file_name    21829 non-null  object
 2   label        21829 non-null  object
dtypes: object(3)
memory usage: 511.7+ KB
None


In [6]:
# get label counts:
# value_counts() orders the labels by descending frequency; report them in that order.
label_counts = metadata["label"].value_counts()
for label, n_label_images in label_counts.items():
    print("Number of {} images: {}".format(label, n_label_images))

Number of contentment images: 5130
Number of amusement images: 4724
Number of awe images: 2881
Number of excitement images: 2725
Number of sadness images: 2633
Number of disgust images: 1591
Number of anger images: 1176
Number of fear images: 969


# Split Dataset

In [7]:
# construct stratified test set by randomly sampling from each class:
# (grouping is along the rows, which is groupby's default - the `axis` keyword is
# deprecated in recent pandas releases and therefore not passed)
metadata_test_groups = metadata.groupby(by="label")
# draw exactly N_class_test rows per label; the fixed random_state makes the draw repeatable
metadata_test = metadata_test_groups.sample(n=N_class_test, random_state=42)
# DataFrame.info() prints its summary itself and returns None, so no print() around it.
metadata_test.info()

# sanity check: every label contributes exactly N_class_test samples
test_label_counts = metadata_test["label"].value_counts()
assert (test_label_counts == N_class_test).all(), "Error with creating stratified test set."

<class 'pandas.core.frame.DataFrame'>
Int64Index: 1600 entries, 5097 to 20417
Data columns (total 3 columns):
 #   Column       Non-Null Count  Dtype 
---  ------       --------------  ----- 
 0   subdir_name  1600 non-null   object
 1   file_name    1600 non-null   object
 2   label        1600 non-null   object
dtypes: object(3)
memory usage: 50.0+ KB
None


In [8]:
# remove test set:
# (metadata_test keeps the row labels of `metadata`, so dropping by index removes exactly
# the sampled rows)
metadata_train_val = metadata.drop(index=metadata_test.index)

# construct stratified validation set by randomly sampling from each class:
# (grouping is along the rows, which is groupby's default - the `axis` keyword is
# deprecated in recent pandas releases and therefore not passed)
metadata_val_groups = metadata_train_val.groupby(by="label")
# draw exactly N_class_val rows per label from what is left after removing the test set
metadata_val = metadata_val_groups.sample(n=N_class_val, random_state=42)
# DataFrame.info() prints its summary itself and returns None, so no print() around it.
metadata_val.info()

# sanity check: every label contributes exactly N_class_val samples
val_label_counts = metadata_val["label"].value_counts()
assert (val_label_counts == N_class_val).all(), "Error with creating stratified validation set."

<class 'pandas.core.frame.DataFrame'>
Int64Index: 400 entries, 7392 to 19945
Data columns (total 3 columns):
 #   Column       Non-Null Count  Dtype 
---  ------       --------------  ----- 
 0   subdir_name  400 non-null    object
 1   file_name    400 non-null    object
 2   label        400 non-null    object
dtypes: object(3)
memory usage: 12.5+ KB
None


In [9]:
# construct training set:
# everything that was sampled into neither the test nor the validation set
metadata_train = metadata_train_val.drop(index=metadata_val.index)
# DataFrame.info() prints its summary itself and returns None, so no print() around it.
metadata_train.info()
print()
# remaining number of training samples per label (not balanced, unlike val/test)
print(metadata_train["label"].value_counts())

<class 'pandas.core.frame.DataFrame'>
Int64Index: 19829 entries, 0 to 21828
Data columns (total 3 columns):
 #   Column       Non-Null Count  Dtype 
---  ------       --------------  ----- 
 0   subdir_name  19829 non-null  object
 1   file_name    19829 non-null  object
 2   label        19829 non-null  object
dtypes: object(3)
memory usage: 619.7+ KB
None

contentment    4880
amusement      4474
awe            2631
excitement     2475
sadness        2383
disgust        1341
anger           926
fear            719
Name: label, dtype: int64


In [10]:
# check that all subsets are disjoint:
# At this point the subsets still carry the row labels of `metadata`, so two subsets
# share a sample exactly when they share a row label.
metadata_subsets = [metadata_train, metadata_val, metadata_test]
# relies on the keys of metadata_files_split being ordered like metadata_subsets
subset_names = list(metadata_files_split.keys())
subset_index_sets = [set(subset.index) for subset in metadata_subsets]
for index_set_1, name_1 in zip(subset_index_sets, subset_names):
    for index_set_2, name_2 in zip(subset_index_sets, subset_names):
        if name_1 == name_2:
            # a subset trivially overlaps with itself
            continue
        assert index_set_1.isdisjoint(index_set_2), "{} and {} are not disjoint".format(name_1, name_2)

In [11]:
# reset indices:
# each subset gets a fresh 0..n-1 index; the row labels inherited from `metadata` are dropped
metadata_train = metadata_train.reset_index(drop=True)
metadata_val = metadata_val.reset_index(drop=True)
metadata_test = metadata_test.reset_index(drop=True)

# sanity checks:
# the three subsets together must contain as many rows as the full dataset
n_samples_in_subsets = len(metadata_train) + len(metadata_val) + len(metadata_test)
assert metadata.shape[0] == n_samples_in_subsets, "Subset set sizes don't add up."
# check that all subsets are disjoint:
# (the indices were just reset, so the comparison is done on the file names instead)
metadata_subsets = [metadata_train, metadata_val, metadata_test]
# relies on the keys of metadata_files_split being ordered like metadata_subsets
subset_names = list(metadata_files_split.keys())
subset_file_names = [set(subset["file_name"].tolist()) for subset in metadata_subsets]
for file_names_1, name_1 in zip(subset_file_names, subset_names):
    for file_names_2, name_2 in zip(subset_file_names, subset_names):
        if name_1 == name_2:
            # a subset trivially overlaps with itself
            continue
        assert file_names_1.isdisjoint(file_names_2), "{} and {} are not disjoint".format(name_1, name_2)
# more sanity checks:
# for every class, the per-subset counts must add up to the count in the full dataset
class_counts_all = metadata["label"].value_counts()
class_counts_train = metadata_train["label"].value_counts()
class_counts_val = metadata_val["label"].value_counts()
class_counts_test = metadata_test["label"].value_counts()
for class_label in metadata_test_groups.groups.keys():
    class_count_in_subsets = (
        class_counts_train[class_label]
        + class_counts_val[class_label]
        + class_counts_test[class_label]
    )
    assert class_counts_all[class_label] == class_count_in_subsets, "Error with splitting dataset."

# save to file:
# one CSV per subset, written without the index column
for split_key, split_metadata in (
    ("train", metadata_train),
    ("val", metadata_val),
    ("test", metadata_test),
):
    split_metadata.to_csv(metadata_files_split[split_key], index=False)