This code splits a MIDI song dataset into training, validation, and test sets. The split is made per song, so all clips from one song end up in the same set. Here’s a breakdown of what each section of the code accomplishes:

1. **Data Loading and Initial Setup**:
   - The code loads the dataset from a CSV file (`metadata_by_song.csv`) and sets the `songID` column as the index. 
   - A `create_split` function is defined to sample a fixed number of songs from each quality category (`DominantQ`).

2. **Splitting into Test and Train/Validation Sets**:
   - `create_split` randomly samples `song_num` songs (8 by default) from each quality category (Q1, Q2, Q3, Q4). It uses a fixed random seed, so the split is reproducible.
   - The result (`test_data`) is set aside as the test set, and the rest of the data (`train_val`) is retained for further splitting.

3. **Creating Validation and Training Sets**:
   - Another call to `create_split` generates the validation set (`val_data`) by sampling from `train_val`.
   - The remaining data in `train_val` becomes the training set (`train_data`).
   - Assertions confirm that the train, validation, and test sets share no song IDs.

4. **Counting and Verifying Ratios**:
   - The code sums the per-category clip counts (`num_Q1` to `num_Q4`) in each set. It then uses `count_ratio` to print what fraction of one category's clips falls into the training, validation, and test sets. For Q1 this is about 80%, 10%, and 10%.

5. **Saving Song-Level Splits**:
   - Each split (train, validation, and test) is saved as a separate CSV file (`train_SL.csv`, `val_SL.csv`, `test_SL.csv`) for easy access.

6. **Linking MIDI Files to Their Split Category**:
   - The code scans the directory `data_root`, which contains the MIDI files. It extracts the song ID from each filename and looks it up in `split_dict`, which maps every song ID to its split (train, validation, or test).
   - Based on this matching, MIDI file names are appended to `train_clips`, `val_clips`, or `test_clips` lists.

7. **Saving Clip-Level Splits**:
   - Finally, it saves the lists of MIDI filenames for each split into separate CSV files (`train_clip.csv`, `val_clip.csv`, `test_clip.csv`) to keep track of individual MIDI clips associated with each split.

In summary, the code splits the MIDI data into training, validation, and test sets at both the song and clip levels. No song is shared between sets, and the validation and test sets contain the same number of songs from each quality category. Because the split is made per song, clips from one song never appear in both training and evaluation data. This setup is typical for machine learning tasks such as classification on this dataset.

In [67]:
import pandas as pd

In [68]:
import glob
import os

In [69]:
src_csv = '../metadata_by_song.csv'
data = pd.read_csv(src_csv)
print(data.shape)
data = data.set_index("songID")
data.head()

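`songID` serves only as the index, so it can also be set while reading, through the `index_col` argument of `read_csv`. Below is a minimal sketch. It uses an in-memory CSV as a hypothetical stand-in for `metadata_by_song.csv`, with two rows copied from the `train_val.head()` output.

```python
import io

import pandas as pd

# Hypothetical two-row stand-in for metadata_by_song.csv (columns as in train_val.head())
csv_text = """songID,num_Q1,num_Q2,num_Q3,num_Q4,DominantQ
e8NQ2NH0nc8,2,0,0,0,1
HQ8ISDX6PiI,0,0,0,2,4
"""

# index_col sets songID as the index at load time, replacing the separate set_index call
data = pd.read_csv(io.StringIO(csv_text), index_col="songID")
print(data.shape)  # (2, 5)
```

In the notebook cell, `data.shape` is printed before `set_index`, so its shape still counts `songID` as a column. Here `songID` is already the index when the shape is printed.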

In [70]:
def create_split(data, song_num=8, random_seed=1):
    '''
    Return a new DataFrame with song_num randomly sampled songs for each
    DominantQ value (1-4). With the default song_num=8, the sampled songs
    hold roughly 1/10 of the clips in each Q.
    '''
    samples = []
    for q in (1, 2, 3, 4):
        songs_in_q = data[data['DominantQ'] == q]
        # Use the random_seed argument (it was previously ignored in favour of a hard-coded 1)
        samples.append(songs_in_q.sample(song_num, random_state=random_seed))
    split_data = pd.concat(samples)
    # Column totals: number of clips per Q in the sampled songs
    print(split_data.sum(axis=0))
    return split_data
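`create_split` filters each DominantQ value and samples it separately. pandas 1.1 and later can perform the same stratified sampling in one call with `GroupBy.sample`. The sketch below is an alternative, not a drop-in replacement. It uses a hypothetical 48-song toy table and a hypothetical function name, `create_split_grouped`. It also draws random numbers in a different order, so for the same seed it will not pick the same songs as `create_split`.

```python
import pandas as pd

# Hypothetical toy metadata: 12 songs for each DominantQ value
data = pd.DataFrame({
    "songID": [f"song{i:02d}" for i in range(48)],
    "DominantQ": [1, 2, 3, 4] * 12,
}).set_index("songID")


def create_split_grouped(data, song_num=8, random_seed=1):
    # Draw song_num rows (songs) from every DominantQ group, reproducibly
    return data.groupby("DominantQ").sample(n=song_num, random_state=random_seed)


split = create_split_grouped(data, song_num=3)
print(split["DominantQ"].value_counts().sort_index())
```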

In [71]:
test_data = create_split(data)

num_Q1       25
num_Q2       19
num_Q3       21
num_Q4       23
DominantQ    80
dtype: int64


In [72]:
train_val = data.drop(index=test_data.index)

In [73]:
train_val.head()

             num_Q1  num_Q2  num_Q3  num_Q4  DominantQ
songID
e8NQ2NH0nc8       2       0       0       0          1
HQ8ISDX6PiI       0       0       0       2          4
ZTrEoB8T9YA       0       3       0       0          2
3N2G21U7guk       3       0       0       4          4
_8v0MFBZoco       2       0       0       0          1


In [74]:
val_data = create_split(train_val)

num_Q1       24
num_Q2       30
num_Q3       33
num_Q4       27
DominantQ    80
dtype: int64


In [75]:
train_data = train_val.drop(index=val_data.index)

In [76]:
# Make sure no song appears in more than one of train, val, and test
assert len(set(train_data.index) & set(val_data.index)) == 0
assert len(set(train_data.index) & set(test_data.index)) == 0
assert len(set(test_data.index) & set(val_data.index)) == 0
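The three assertions check that the splits are pairwise disjoint. They do not check that the splits together cover every song. Below is a minimal sketch of a combined check, using a hypothetical helper `check_partition` and toy data.

```python
import pandas as pd


def check_partition(full, *splits):
    """Assert the splits are pairwise disjoint and together cover every song in full."""
    seen = set()
    for part in splits:
        ids = set(part.index)
        assert not (seen & ids), "a song appears in more than one split"
        seen |= ids
    assert seen == set(full.index), "some songs are missing from every split"


# Hypothetical toy data: six songs split 4 / 1 / 1
full = pd.DataFrame({"DominantQ": [1, 2, 3, 4, 1, 2]},
                    index=pd.Index(list("abcdef"), name="songID"))
check_partition(full, full.loc[["a", "b", "c", "d"]], full.loc[["e"]], full.loc[["f"]])
```

In the notebook, this would be called as `check_partition(data, train_data, val_data, test_data)`.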

In [77]:
train_num = train_data.sum()
val_num = val_data.sum()
test_num = test_data.sum()

In [78]:
def count_ratio(nums, Q=1):
    '''Print the share of num_Q<Q> clips that falls into train, val and test.

    nums is [train_num, val_num, test_num], each a Series of column sums.
    '''
    col = 'num_Q' + str(Q)
    total = sum(x[col] for x in nums)
    for name, x in zip(['train', 'val  ', 'test '], nums):
        print('{}: {}'.format(name, x[col] / total))

In [79]:
nums = [train_num, val_num, test_num]
count_ratio(nums, Q=1)


train: 0.804
val  : 0.096
test : 0.1
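`count_ratio` reports one Q category at a time. The sketch below tabulates all four at once, assuming `train_num`, `val_num` and `test_num` are the column-sum Series computed above. The helper name `ratio_table` is hypothetical. In the toy input, the validation and test totals are the ones printed earlier. The train `num_Q1` value (201) is derived from the Q1 ratios (25 / 0.1 = 250 clips in total). The other train totals are placeholders.

```python
import pandas as pd


def ratio_table(train_num, val_num, test_num):
    """Row-normalised table: fraction of each Q category's clips in train/val/test."""
    counts = pd.DataFrame({"train": train_num, "val": val_num, "test": test_num})
    # Keep only the clip-count rows; the DominantQ sum is not a clip count
    counts = counts.loc[["num_Q1", "num_Q2", "num_Q3", "num_Q4"]]
    return counts.div(counts.sum(axis=1), axis=0)


cols = ["num_Q1", "num_Q2", "num_Q3", "num_Q4", "DominantQ"]
train_num = pd.Series([201, 170, 150, 200, 1000], index=cols)  # placeholders except num_Q1
val_num = pd.Series([24, 30, 33, 27, 80], index=cols)
test_num = pd.Series([25, 19, 21, 23, 80], index=cols)

print(ratio_table(train_num, val_num, test_num).round(3))
```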


In [80]:
# Save the song-level split
os.makedirs('../split', exist_ok=True)

train_data.to_csv('../split/train_SL.csv')
val_data.to_csv('../split/val_SL.csv')
test_data.to_csv('../split/test_SL.csv')
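`to_csv` writes the `songID` index as the first column. To recover the same DataFrame, the split files should therefore be read back with `index_col="songID"`. Below is a minimal round-trip sketch that uses a temporary directory and a hypothetical one-song split.

```python
import os
import tempfile

import pandas as pd

# Hypothetical one-song split standing in for train_data / val_data / test_data
split_df = pd.DataFrame(
    {"num_Q1": [2], "num_Q2": [0], "num_Q3": [0], "num_Q4": [0], "DominantQ": [1]},
    index=pd.Index(["e8NQ2NH0nc8"], name="songID"),
)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "train_SL.csv")
    split_df.to_csv(path)  # songID is written as the first column
    reloaded = pd.read_csv(path, index_col="songID")

print(reloaded.equals(split_df))
```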

In [81]:
data_root = '../midis'

In [82]:
len(data_root)

8

In [83]:
midi_files = glob.glob(os.path.join(data_root, '*.mid'))

In [84]:
split_dict = {}
for song in train_data.index:
    split_dict[song] = 'train'
for song in val_data.index:
    split_dict[song] = 'val'
for song in test_data.index:
    split_dict[song] = 'test'
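The three loops can be condensed with `dict.fromkeys`, which maps every key to the same value. The sketch below uses hypothetical toy splits: the song IDs come from the `train_val.head()` output, but their split assignments are invented. As with the loops, a song present in two splits would silently keep its last assignment. This is one reason the earlier disjointness assertions matter.

```python
import pandas as pd


def toy_split(ids, quadrants):
    # Hypothetical helper: a minimal song-level split indexed by songID
    return pd.DataFrame({"DominantQ": quadrants}, index=pd.Index(ids, name="songID"))


train_data = toy_split(["e8NQ2NH0nc8", "ZTrEoB8T9YA"], [1, 2])
val_data = toy_split(["HQ8ISDX6PiI"], [4])
test_data = toy_split(["_8v0MFBZoco"], [1])

# Map every songID in each split to that split's name; later updates overwrite earlier ones
split_dict = {}
for name, split in (("train", train_data), ("val", val_data), ("test", test_data)):
    split_dict.update(dict.fromkeys(split.index, name))

print(split_dict)
```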

    

In [85]:
# formatted as: ../midis/Q3_egYSmNuIFGk_1.mid

train_clips = []
val_clips = []
test_clips = []

for mid in midi_files:
    # basename strips the directory part (previously mid[9:], i.e. len('../midis/'))
    filename = os.path.basename(mid)
    # Characters 3-13 hold the 11-character song ID, e.g. 'egYSmNuIFGk' in 'Q3_egYSmNuIFGk_1.mid'
    songname = filename[3:14]
    split = split_dict[songname]
    if split == 'train':
        train_clips.append(filename)
    elif split == 'val':
        val_clips.append(filename)
    elif split == 'test':
        test_clips.append(filename)
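Extracting the song ID with the fixed slice `filename[3:14]` assumes that every ID is exactly 11 characters long. A more defensive parser could use a regular expression. The sketch below assumes every clip follows the `Q<quadrant>_<songID>_<clip index>.mid` pattern shown in the cell comment. `CLIP_RE` and `parse_clip_name` are hypothetical names. Splitting on `_` alone would not work, because some song IDs contain underscores (for example `_8v0MFBZoco`).

```python
import re

# Assumed clip naming: Q<quadrant>_<songID>_<clip index>.mid, e.g. Q3_egYSmNuIFGk_1.mid.
# The greedy (.+) keeps underscores inside the song ID, because only the final
# "_<digits>.mid" is treated as the clip index.
CLIP_RE = re.compile(r"^Q([1-4])_(.+)_(\d+)\.mid$")


def parse_clip_name(filename):
    """Return (quadrant, song_id, clip_index), or None if the name does not match."""
    m = CLIP_RE.match(filename)
    if m is None:
        return None
    return int(m.group(1)), m.group(2), int(m.group(3))


print(parse_clip_name("Q3_egYSmNuIFGk_1.mid"))   # (3, 'egYSmNuIFGk', 1)
print(parse_clip_name("Q1__8v0MFBZoco_12.mid"))  # (1, '_8v0MFBZoco', 12)
```

Returning `None` lets the loop skip files that do not match. In the same way, `split_dict.get(song_id)` would skip songs missing from the metadata instead of raising `KeyError`.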

In [86]:
# Save the clip lists
train_df = pd.DataFrame({'clip_name': train_clips})
train_df.to_csv('../split/train_clip.csv')

val_df = pd.DataFrame({'clip_name': val_clips})
val_df.to_csv('../split/val_clip.csv')

test_df = pd.DataFrame({'clip_name': test_clips})
test_df.to_csv('../split/test_clip.csv')
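As a final sanity check, the number of clips found for each split could be compared with the clip counts in the song-level metadata. The sketch below assumes that `num_Q1` to `num_Q4` count each song's clips in the MIDI folder. That reading fits the per-category sums used by `count_ratio`, but the notebook does not state it. `clip_count_mismatch` is a hypothetical helper.

```python
import pandas as pd


def clip_count_mismatch(split_df, clips):
    """Expected clips (sum of num_Q1..num_Q4 over the split's songs) minus clips found."""
    cols = ["num_Q1", "num_Q2", "num_Q3", "num_Q4"]
    expected = int(split_df[cols].to_numpy().sum())
    return expected - len(clips)


# Hypothetical toy example: one song with 2 clips in the metadata, 2 files found
toy = pd.DataFrame({"num_Q1": [2], "num_Q2": [0], "num_Q3": [0], "num_Q4": [0]},
                   index=pd.Index(["e8NQ2NH0nc8"], name="songID"))
print(clip_count_mismatch(toy, ["Q1_e8NQ2NH0nc8_1.mid", "Q1_e8NQ2NH0nc8_2.mid"]))  # 0
```

In the notebook, a nonzero result from `clip_count_mismatch(train_data, train_clips)` would flag missing or extra MIDI files for that split.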