# Setup

In [1]:
import pandas as pd

In [2]:
# data paths:
# label files of the original dataset split, keyed by original subset name
# (insertion order: "unbalanced_train" first, then "eval")
LABEL_FILES_ORIG_SPLIT = dict([
    ("unbalanced_train", "/proj/systewar/datasets/audioset_music_mood/orig_split_label_files/labels_unbalanced_train.csv"),
    ("eval", "/proj/systewar/datasets/audioset_music_mood/orig_split_label_files/labels_eval.csv"),
])

In [3]:
# script options:
# number of samples per class in (new) test set
N_class_test = 100

# output label files of the new split, keyed by new subset name
label_files_new_split = {}
label_files_new_split["train"] = "/proj/systewar/datasets/audioset_music_mood/labels_train.csv"
label_files_new_split["test"] = "/proj/systewar/datasets/audioset_music_mood/labels_test.csv"

# Split Dataset

In [4]:
# load original split labels (one dataframe per original subset, keyed by subset name):
labels_orig_split = {}
for subset, file_path in LABEL_FILES_ORIG_SPLIT.items():
    print("Loading {} set labels...".format(subset))
    labels_orig_split[subset] = pd.read_csv(file_path)

# concatenate original split labels into a single dataframe;
# ignore_index=True renumbers the rows 0..N-1 in the same call, so every row
# gets a unique index label (replaces the separate reset_index(drop=True) step):
all_labels = pd.concat(list(labels_orig_split.values()), axis="index", ignore_index=True)
print()
# DataFrame.info() prints its summary itself and returns None, so it is called
# directly -- wrapping it in print() would emit a stray "None" line:
all_labels.info()

Loading unbalanced_train set labels...
Loading eval set labels...

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 13395 entries, 0 to 13394
Data columns (total 3 columns):
 #   Column       Non-Null Count  Dtype 
---  ------       --------------  ----- 
 0   orig_subset  13395 non-null  object
 1   file_name    13395 non-null  object
 2   label        13395 non-null  object
dtypes: object(3)
memory usage: 314.1+ KB
None


In [5]:
# group by class label:
# (rows are grouped by default; the explicit `axis` keyword of DataFrame.groupby
# is deprecated in recent pandas releases, so it is not passed)
all_label_groups = all_labels.groupby(by="label")

# construct stratified test set by randomly sampling N_class_test rows from each
# class (without replacement; fixed random_state makes the split reproducible):
labels_test = all_label_groups.sample(n=N_class_test, random_state=42)
# DataFrame.info() prints its summary itself and returns None, so it is called
# directly -- wrapping it in print() would emit a stray "None" line:
labels_test.info()

# sanity check: every class present in the test set has exactly N_class_test rows
for count in labels_test["label"].value_counts():
    assert count == N_class_test, "Error with creating stratified test set."

<class 'pandas.core.frame.DataFrame'>
Int64Index: 700 entries, 10980 to 3237
Data columns (total 3 columns):
 #   Column       Non-Null Count  Dtype 
---  ------       --------------  ----- 
 0   orig_subset  700 non-null    object
 1   file_name    700 non-null    object
 2   label        700 non-null    object
dtypes: object(3)
memory usage: 21.9+ KB
None


In [6]:
# construct training set: every row of all_labels whose index label was not
# sampled into the test set (labels_test still carries all_labels' index labels here):
labels_train = all_labels.drop(index=labels_test.index)
# DataFrame.info() prints its summary itself and returns None, so it is called
# directly -- wrapping it in print() would emit a stray "None" line:
labels_train.info()
print()
# per-class sample counts remaining in the training set:
print(labels_train["label"].value_counts())

# sanity check: no index label is shared between the two subsets
assert set(labels_train.index).isdisjoint(set(labels_test.index)), "Train and test sets are not disjoint."

<class 'pandas.core.frame.DataFrame'>
Int64Index: 12695 entries, 0 to 13394
Data columns (total 3 columns):
 #   Column       Non-Null Count  Dtype 
---  ------       --------------  ----- 
 0   orig_subset  12695 non-null  object
 1   file_name    12695 non-null  object
 2   label        12695 non-null  object
dtypes: object(3)
memory usage: 396.7+ KB
None

Exciting music    4431
Tender music      3214
Scary music       1253
Sad music         1242
Happy music       1006
Angry music        794
Funny music        755
Name: label, dtype: int64


In [7]:
# reset indices: renumber the rows of both subsets from zero, discarding the old index labels
labels_train, labels_test = (
    labels_train.reset_index(drop=True),
    labels_test.reset_index(drop=True),
)

# more sanity checks:
# 1) the two subsets together contain exactly as many rows as the full label table
n_rows_split = len(labels_train) + len(labels_test)
assert len(all_labels) == n_rows_split, "Train and test set sizes don't add up."

# 2) no file name appears in both subsets
train_files = set(labels_train["file_name"])
test_files = set(labels_test["file_name"])
assert train_files.isdisjoint(test_files), "Train and test sets are not disjoint."

# 3) for every class, the train and test counts sum to the overall count
class_counts_all = all_labels["label"].value_counts()
class_counts_train = labels_train["label"].value_counts()
class_counts_test = labels_test["label"].value_counts()
for class_label in all_label_groups.groups:
    n_class_split = class_counts_train[class_label] + class_counts_test[class_label]
    assert class_counts_all[class_label] == n_class_split, "Error with splitting dataset."

# save to files (train first, then test; the index column is not written):
for split_name, split_labels in (("train", labels_train), ("test", labels_test)):
    split_labels.to_csv(label_files_new_split[split_name], index=False)