In [1]:
import os
import numpy as np
import pandas as pd
import torch
import torchaudio
from sklearn.model_selection import train_test_split

In [2]:
# Where the raw recordings live and where their label table is read from.
input_dir = 'all_data/files'
# read_csv already returns a DataFrame, so it is used directly.
df_metadata = pd.read_csv('all_data/metadata/spikerbox_recordings.csv')
# Root folder under which the train/valid/test folders are created.
output_dir = '.'
# Duration of one piece in seconds (multiplied by the sample rate when splitting).
piece_length = 0.96

# Create the files/ and metadata/ folders of every split up front so the
# later cells can write into them; exist_ok keeps the cell safe to re-run.
for split_subdir in (
    'train_data/files',
    'train_data/metadata',
    'valid_data/files',
    'valid_data/metadata',
    'test_data/files',
    'test_data/metadata',
):
    os.makedirs(os.path.join(output_dir, split_subdir), exist_ok=True)

In [3]:
def _split_group(group):
    """Split one treatment group into train, hold-out, valid and test frames.

    30% of the rows are held out first; one third of that hold-out becomes the
    test set and the rest the validation set (roughly 70/20/10 overall).
    Both splits use the fixed seed 2024 so the partition is reproducible.
    """
    train_part, holdout_part = train_test_split(
        group, test_size = 0.3, random_state = 2024
    )
    valid_part, test_part = train_test_split(
        holdout_part, test_size = 1/3, random_state = 2024
    )
    return train_part, holdout_part, valid_part, test_part


# Split each treatment group separately so every split keeps both groups.
df_control_group = df_metadata[df_metadata['treatment'] == 0]
df_treatment_group = df_metadata[df_metadata['treatment'] == 1]

(
    df_control_group_train,
    df_control_group_temp,
    df_control_group_valid,
    df_control_group_test,
) = _split_group(df_control_group)
(
    df_treatment_group_train,
    df_treatment_group_temp,
    df_treatment_group_valid,
    df_treatment_group_test,
) = _split_group(df_treatment_group)

# ignore_index=True already yields a fresh 0..n-1 index for each split.
df_metadata_train = pd.concat(
    [df_control_group_train, df_treatment_group_train], ignore_index = True
)
df_metadata_valid = pd.concat(
    [df_control_group_valid, df_treatment_group_valid], ignore_index = True
)
df_metadata_test = pd.concat(
    [df_control_group_test, df_treatment_group_test], ignore_index = True
)

In [4]:
def split_wav_file(input_dir, output_dir, piece_length, file_name, df_metadata):
    """Split one WAV recording into fixed-length pieces and update its metadata.

    The recording is cut into consecutive pieces of ``piece_length`` seconds;
    the last piece keeps whatever samples remain, so it may be shorter. Each
    piece is saved in ``output_dir`` as ``<stem>_<k>.wav`` (``k`` starts at 1)
    and the metadata row of the recording is replaced by one row per piece
    carrying the same labels.

    Parameters
    ----------
    input_dir : str
        Full path of the WAV file to split. Despite the name this is a file
        path, not a directory: piece names are derived from its basename.
    output_dir : str
        Directory that receives the piece files.
    piece_length : float
        Piece duration in seconds.
    file_name : str
        Value of the ``filename`` column that identifies the recording's row.
    df_metadata : pandas.DataFrame
        Frame with ``filename``, ``performance``, ``treatment`` and
        ``stresslevel`` columns. It is not modified.

    Returns
    -------
    pandas.DataFrame
        New frame in which the recording's row is replaced by its piece rows.
        The piece rows come first (highest piece number on top), followed by
        the untouched rows, with a fresh 0..n-1 index. If the recording holds
        no samples, the remaining rows are returned with their index as is.

    Raises
    ------
    IndexError
        If no row of ``df_metadata`` has ``filename == file_name``.
    ValueError
        If ``piece_length`` is shorter than one sample at the file's rate.
    """
    is_source_row = df_metadata['filename'] == file_name
    source_rows = df_metadata[is_source_row]
    if source_rows.empty:
        raise IndexError(f"no metadata row with filename {file_name!r}")

    # Look the row up once and reuse its label for all three fields.
    first_label = source_rows.index[0]
    performance = df_metadata.loc[first_label, 'performance']
    treatment = df_metadata.loc[first_label, 'treatment']
    stresslevel = df_metadata.loc[first_label, 'stresslevel']

    # Load from the path given by the caller rather than from a notebook
    # global, so the function works outside the loop that drives it.
    waveform, sample_rate = torchaudio.load(input_dir)
    samples_per_piece = int(piece_length * sample_rate)
    if samples_per_piece <= 0:
        raise ValueError(
            f"piece_length={piece_length!r} is shorter than one sample "
            f"at {sample_rate} Hz"
        )
    total_samples = waveform.size(1)
    # Ceiling division: a trailing partial piece still counts as a piece.
    num_pieces = (total_samples + samples_per_piece - 1) // samples_per_piece
    piece_stem = os.path.splitext(os.path.basename(input_dir))[0]

    # Collect the rows by column name and build the frame once, instead of
    # one positional concat per piece.
    piece_records = []
    for i in range(num_pieces):
        start_sample = i * samples_per_piece
        end_sample = min(start_sample + samples_per_piece, total_samples)
        piece_name = f"{piece_stem}_{i + 1}.wav"
        torchaudio.save(
            os.path.join(output_dir, piece_name),
            waveform[:, start_sample:end_sample],
            sample_rate,
        )
        piece_records.append({
            'filename': piece_name,
            'performance': performance,
            'treatment': treatment,
            'stresslevel': stresslevel,
        })

    # Non-mutating replacement for the former in-place drop.
    remaining = df_metadata.loc[~is_source_row]
    if not piece_records:
        return remaining

    # Reversed so the rows come out in the order earlier versions produced
    # (each new piece used to be prepended to the frame).
    pieces = pd.DataFrame(piece_records[::-1], columns=df_metadata.columns)
    return pd.concat([pieces, remaining], ignore_index=True)

In [5]:
# Cut every training recording into pieces. The file names are snapshotted
# first because df_metadata_train is rebound to the updated frame on each pass.
train_files_dir = os.path.join(output_dir, 'train_data/files')
for file_name in df_metadata_train.filename.tolist():
    # split_wav_file takes the full path of the recording as `input_dir`.
    file_path = os.path.join(input_dir, file_name)
    df_metadata_train = split_wav_file(
        input_dir=file_path,
        output_dir=train_files_dir,
        piece_length=piece_length,
        file_name=file_name,
        df_metadata=df_metadata_train,
    )

In [6]:
# Cut every validation recording into pieces. The file names are snapshotted
# first because df_metadata_valid is rebound to the updated frame on each pass.
valid_files_dir = os.path.join(output_dir, 'valid_data/files')
for file_name in df_metadata_valid.filename.tolist():
    # split_wav_file takes the full path of the recording as `input_dir`.
    file_path = os.path.join(input_dir, file_name)
    df_metadata_valid = split_wav_file(
        input_dir=file_path,
        output_dir=valid_files_dir,
        piece_length=piece_length,
        file_name=file_name,
        df_metadata=df_metadata_valid,
    )

In [7]:
# Cut every test recording into pieces. The file names are snapshotted
# first because df_metadata_test is rebound to the updated frame on each pass.
test_files_dir = os.path.join(output_dir, 'test_data/files')
for file_name in df_metadata_test.filename.tolist():
    # split_wav_file takes the full path of the recording as `input_dir`.
    file_path = os.path.join(input_dir, file_name)
    df_metadata_test = split_wav_file(
        input_dir=file_path,
        output_dir=test_files_dir,
        piece_length=piece_length,
        file_name=file_name,
        df_metadata=df_metadata_test,
    )

In [8]:
# Write the per-piece label table of each split next to its audio files.
for labels_relpath, labels_frame in (
    ('train_data/metadata/file_labels.csv', df_metadata_train),
    ('valid_data/metadata/file_labels.csv', df_metadata_valid),
    ('test_data/metadata/file_labels.csv', df_metadata_test),
):
    labels_frame.to_csv(os.path.join(output_dir, labels_relpath), index = False)