In [None]:
from functools import partial
import shutil
import typing as t
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from tqdm import tqdm
import matplotlib.pyplot as plt

import pyarrow as pa
from pyarrow.parquet import SortingColumn
from librosa import amplitude_to_db

from b2aiprep.prepare.bids import get_audio_paths

In [None]:
# Build a small demo copy of the release: copy each selected participant's
# folder, then filter the participant-level and phenotype tables down to them.
demo_subjects = [
    '2b0751d5-10a9-46b8-9456-860e3a231588',
    '2b1c662a-59e0-4b69-9161-a2fd4636890b',
    '2c95136b-0efb-4d67-8fdf-113a8af5ac12',
    '2cb188fe-4e19-4da1-a1f4-1f9f254fd1fd',
    '2de86669-cc1f-45ef-9d31-a7eabd68f247',
    '2e59cc45-df19-4e2d-a1d8-60c46cf47a63',
    '04f6473c-99f1-4b65-bfe3-fcd9fa06dfc9',
    '0b8dcad5-a2a3-4373-bdca-7289b3cbadc7',
    '0e2df8b3-a93f-4982-a82c-d96a5c64d153',
    '0ee1e1e1-0e86-42cc-9e9d-2cafd9f1e01c',
    '1a7a86df-e379-40ab-a644-9821aac7be63',
    '1a7d3f7f-714c-4727-ab83-1b400fab2070',
    '1b07b18b-26f9-405b-a466-29442306a7fe',
    '1c2f13bb-909b-422d-ab61-e3b2ae203ed3',
    '1ce2db0c-505b-4bab-8cfd-04f1d5e08c4d',
    '1f9475bb-f13b-4f68-969b-28f20455b3e7',
    '1fcc920c-8038-4b75-8749-95bb4e4fff65',
    '2a86baba-6c34-4c34-9f7e-1d49d1605f6a',
    '2e71bcdb-b388-4901-8c7d-2a3afeb80fc1',
    '2f58e101-759a-45d4-85ff-0edf53638e52',
    '2f9f315d-70df-4a3e-992a-c224f6cca6c6',
    '3a744c68-53bc-438b-b5c8-7525150504d9',
    '3b7ace99-35f8-4245-a4ce-fab50358b2c4',
    '3bbc69ef-babd-499e-add2-1c9376ed62c6',
    '3e55cedf-26a8-41f8-9704-12160fa43305',
    '3ebb00fa-2e73-48b0-93e2-2ead5f573342',
    '4af9f960-2f33-4486-b4ff-3941f6440973',
    '4c55b481-e8a5-41fe-94c3-61d64f7a0ae2',
    '4ce6a1af-c123-41b7-a438-c58a41d1a3d7',
    '5c60197c-97dc-403d-bc93-4f1dbe9526f1',
]

original_path = Path.home().joinpath('data', 'bridge2ai', 'release-1.0')
data_path = Path.home().joinpath('data', 'bridge2ai', 'release-1.0-demo')

# Create the output tree up front. copytree only creates the per-subject
# folders, so without this the phenotype/ directory would be missing and
# DataFrame.to_csv below would fail on a fresh run.
data_path.mkdir(parents=True, exist_ok=True)
demo_phenotype_path = data_path.joinpath('phenotype')
demo_phenotype_path.mkdir(parents=True, exist_ok=True)

# first copy over folders for each subject
for subject in tqdm(demo_subjects, desc='Copying subject folders'):
    subject_folder = f'sub-{subject}'
    shutil.copytree(original_path.joinpath(subject_folder), data_path.joinpath(subject_folder), dirs_exist_ok=True)

# Filter participants.tsv to the demo subjects (matched on the record_id column).
participants_df = pd.read_csv(original_path.joinpath('participants.tsv'), sep='\t')
participants_df = participants_df[participants_df['record_id'].isin(demo_subjects)]
participants_df.to_csv(data_path.joinpath('participants.tsv'), sep='\t', index=False)

# Filter every phenotype table the same way and copy its JSON sidecar unchanged.
for data_filepath in original_path.joinpath('phenotype').glob('*.tsv'):
    phenotype_df = pd.read_csv(data_filepath, sep='\t')
    phenotype_df = phenotype_df[phenotype_df['record_id'].isin(demo_subjects)]
    phenotype_df.to_csv(demo_phenotype_path.joinpath(data_filepath.name), sep='\t', index=False)

    # copy over corresponding json file
    json_filepath = data_filepath.with_suffix('.json')
    shutil.copy(json_filepath, demo_phenotype_path.joinpath(json_filepath.name))

# copy over top-level json files directly
for json_filename in ['dataset_description.json', 'participants.json']:
    shutil.copy(original_path.joinpath(json_filename), data_path.joinpath(json_filename))

In [None]:
bids_path = Path.home() / 'data' / 'bridge2ai' / 'release-1.0-demo'
output_base_path = Path.home() / 'data' / 'bridge2ai' / 'release-1.0-demo-processed'
audio_paths = get_audio_paths(bids_path)
# Debug subset: keep only recordings whose file name starts with "sub-0",
# which keeps iteration fast while prototyping.
audio_paths = [
    audio_info
    for audio_info in audio_paths
    if audio_info['path'].name.startswith('sub-0')
]
print(f'Found {len(audio_paths)} audio files')
audio_paths[:2]

In [None]:
# Option 1: load every spectrogram into one dict keyed by recording stem and
# save the whole dict as a single PyTorch file.
spectrograms = {}
for audio_info in audio_paths:
    wav_path = audio_info['path']
    # Feature files sit next to the audio as "<stem>_features.pt".
    features_file = wav_path.with_name(f'{wav_path.stem}_features.pt')
    # weights_only=False unpickles arbitrary objects; only use on trusted local files.
    features = torch.load(features_file, weights_only=False)
    spectrograms[wav_path.stem] = features['torchaudio']['spectrogram']

output_path = output_base_path / 'option_1'
output_path.mkdir(parents=True, exist_ok=True)
torch.save(spectrograms, output_path / 'spectrograms.pt')

In [None]:
# Option 2: Save as parquet
import pyarrow.parquet as pq

def write_spectrograms(spec_dict, output_file):
    """Write a dict of 2-D spectrograms to a single zstd-compressed Parquet file.

    Each spectrogram becomes one row with three columns:
      - ``id``: the dict key,
      - ``shape``: ``[n_rows, n_cols]`` of the original 2-D array,
      - ``data``: the row-major flattened values.

    Args:
        spec_dict: mapping of id -> 2-D array-like exposing ``.shape``,
            ``.flatten()`` and ``.tolist()`` (torch tensors and numpy arrays both work).
        output_file: destination path for the Parquet file.
    """
    ids = list(spec_dict.keys())
    specs = list(spec_dict.values())
    table = pa.Table.from_pydict({
        'id': ids,
        # Store BOTH dimensions. Previously only n_cols was stored and the
        # reader had to assume a fixed row count.
        'shape': [[int(s.shape[0]), int(s.shape[1])] for s in specs],
        'data': [s.flatten().tolist() for s in specs],
    })
    pq.write_table(table, output_file, compression='zstd', compression_level=3)


# Row count assumed for files written by the earlier format of
# write_spectrograms, which stored only the number of columns.
_LEGACY_N_ROWS = 513


def read_spectrograms(input_file):
    """Read a Parquet file produced by ``write_spectrograms`` back into a dict.

    Files whose ``shape`` column holds a single value (the earlier format) are
    reshaped to ``(_LEGACY_N_ROWS, shape[0])``.

    Returns:
        dict mapping id -> numpy array with the stored 2-D shape.
    """
    table = pq.read_table(input_file)
    spec_dict = {}
    # Convert each column to Python once instead of indexing element-by-element.
    for id_, shape, data in zip(
        table['id'].to_pylist(),
        table['shape'].to_pylist(),
        table['data'].to_pylist(),
    ):
        if len(shape) == 1:
            n_rows, n_cols = _LEGACY_N_ROWS, shape[0]
        else:
            n_rows, n_cols = shape
        spec_dict[id_] = np.array(data).reshape(n_rows, n_cols)
    return spec_dict

# Write Option 2's Parquet file into its own sub-directory.
output_path = output_base_path / 'option_2'
output_path.mkdir(exist_ok=True, parents=True)
write_spectrograms(spectrograms, output_path / 'spectrograms.parquet')

In [None]:
# Option 3: HF dataset

def audio_feature_generator(
    audio_paths,
) -> t.Generator[t.Dict[str, t.Any], None, None]:
    """Yield one ``{'spectrogram': float32 ndarray}`` record per audio path.

    For each path, loads the sibling ``<stem>_features.pt`` file and extracts
    ``['torchaudio']['spectrogram']``. Records are shaped for HuggingFace's
    ``Dataset.from_generator``.
    """
    progress = tqdm(audio_paths, total=len(audio_paths), desc="Extracting features")
    for audio_path in progress:
        features_file = audio_path.parent / f"{audio_path.stem}_features.pt"
        # NOTE: weights_only=False allows arbitrary unpickling; trusted local files only.
        loaded = torch.load(features_file, weights_only=False)
        spectrogram = loaded['torchaudio']['spectrogram'].numpy().astype(np.float32)
        yield {'spectrogram': spectrogram}

# Normalise to plain Path objects. Items may be the original dicts from
# get_audio_paths or paths that were already converted, so re-running this
# cell no longer fails with "'PosixPath' object is not subscriptable".
audio_paths = [
    Path(item) if isinstance(item, (str, Path)) else Path(item['path'])
    for item in audio_paths
]
audio_feature_generator_partial = partial(audio_feature_generator, audio_paths=audio_paths)
ds = Dataset.from_generator(audio_feature_generator_partial, num_proc=1)

output_path = output_base_path.joinpath('option_3')
output_path.mkdir(parents=True, exist_ok=True)
ds.to_parquet(
    str(output_path / "spectrograms.parquet"),
    # pyarrow Parquet writer options
    compression="snappy",
)

In [None]:
# Option 4: one .pt file per spectrogram, named after the recording stem.
output_path = output_base_path / 'option_4'
output_path.mkdir(parents=True, exist_ok=True)
for stem, spectrogram_tensor in spectrograms.items():
    torch.save(spectrogram_tensor, output_path / f'{stem}.pt')

In [None]:
# Option 5: Use a pandas dataframe intermediary and partitioning
# Previously Dataset.from_pandas failed with
#   > ArrowInvalid: ('Can only convert 1-dimensional array values', 'Conversion failed for column array with type object')
# because Arrow cannot convert 2-D numpy arrays held in an object column.
# Each spectrogram is now stored as a list of 1-D row arrays instead.

def create_partitioned_dataset(
        spectrograms: t.Dict[str, np.ndarray],
        # tasks: t.List[str], 
        save_dir: str, partition_size: int = 1000
):
    """Convert spectrograms to a HF ``Dataset`` indexed by id and save it in shards.

    Args:
        spectrograms: mapping of id -> 2-D spectrogram (numpy array or CPU torch tensor).
        save_dir: directory passed to ``Dataset.save_to_disk``.
        partition_size: maximum number of rows per shard (must be >= 1).

    Returns:
        The ``Dataset`` that was written.

    Raises:
        ValueError: if ``partition_size`` < 1.
    """
    if partition_size < 1:
        raise ValueError(f'partition_size must be >= 1, got {partition_size}')

    ids = list(spectrograms.keys())
    # np.asarray accepts both numpy arrays and CPU torch tensors.
    arrays = [np.asarray(x) for x in spectrograms.values()]
    # TODO: add a 'task' column and an (id, task) index once task names are passed in.
    df = pd.DataFrame({
        'id': ids,
        'array_length': [arr.shape[1] for arr in arrays],
        # A list of 1-D rows (not a 2-D ndarray) lets Arrow infer a nested list column.
        'array': [list(arr) for arr in arrays],
    }).set_index('id')

    # Convert to HF dataset; preserve_index keeps 'id' as a column.
    dataset = Dataset.from_pandas(df, preserve_index=True)

    # save_to_disk accepts num_shards OR max_shard_size, not both.
    # Ceil-divide so no shard holds more than partition_size rows.
    num_shards = max(1, -(-len(df) // partition_size))
    dataset.save_to_disk(save_dir, num_shards=num_shards)
    return dataset

def load_partitioned_dataset(save_dir: str):
    """Reload a dataset directory previously written with ``Dataset.save_to_disk``."""
    loaded = Dataset.load_from_disk(save_dir)
    return loaded


output_path = output_base_path / 'option_5'
output_path.mkdir(exist_ok=True, parents=True)
# The write itself is currently disabled; uncomment to produce option 5.
# dataset = create_partitioned_dataset(
#     spectrograms=spectrograms,
#     # tasks=tasks,
#     save_dir=str(output_path),
# )

In [None]:
# option 7: the above was bigger so we use the generator from option 3
def audio_feature_generator_sorted(
    audio_paths,
) -> t.Generator[t.Dict[str, t.Any], None, None]:
    """Yield ``record_id``/``task_name``/``spectrogram`` records ordered by subject, then task.

    File stems are split on ``_``. Field 0 (minus its 4-character ``sub-``
    prefix) becomes ``record_id``; field 2 (minus its 5-character ``task-``
    prefix) becomes ``task_name``. Records are shaped for HuggingFace's
    ``Dataset.from_generator``.
    """
    def subject_then_task(path):
        parts = path.stem.split('_')
        return parts[0], parts[2]

    ordered_paths = sorted(audio_paths, key=subject_then_task)

    for wav_path in tqdm(ordered_paths, total=len(ordered_paths), desc="Extracting features"):
        features = torch.load(wav_path.parent / f"{wav_path.stem}_features.pt", weights_only=False)
        stem_parts = wav_path.stem.split('_')
        yield {
            'record_id': stem_parts[0][4:],  # drop "sub-"
            'task_name': stem_parts[2][5:],  # drop "task-"
            'spectrogram': features['torchaudio']['spectrogram'].numpy().astype(np.float32),
        }

audio_feature_generator_sorted_partial = partial(
    audio_feature_generator_sorted,
    audio_paths=audio_paths,
)

# The generator already yields rows ordered by record_id, then task_name.
ds = Dataset.from_generator(audio_feature_generator_sorted_partial, num_proc=1)
# Declare that ordering in the Parquet metadata (column 0 = record_id, column 1 = task_name).
sorting_columns = [
    SortingColumn(column_index=column, descending=False)
    for column in (0, 1)
]

output_path = output_base_path / 'option_7'
output_path.mkdir(parents=True, exist_ok=True)

# snappy compression, dictionary-encoded id columns, statistics and page index
ds.to_parquet(
    str(output_path / "spectrograms.parquet"),
    compression="snappy",
    use_dictionary=["record_id", "task_name"],
    write_statistics=True,
    data_page_size=1_048_576,
    write_page_index=True,
    sorting_columns=sorting_columns,
)

In [None]:
# Option 8: option 7's layout, recompressed with zstd at level 9.
output_path = output_base_path / 'option_8'
output_path.mkdir(parents=True, exist_ok=True)
zstd9_writer_options = {
    "compression": "zstd",
    "compression_level": 9,
    "use_dictionary": ["record_id", "task_name"],
    "write_statistics": True,
    "data_page_size": 1_048_576,
    "write_page_index": True,
    "sorting_columns": sorting_columns,
}
ds.to_parquet(str(output_path / "spectrograms.parquet"), **zstd9_writer_options)

In [None]:
# Option 9: option 7's layout, recompressed with zstd at level 3.
output_path = output_base_path / 'option_9'
output_path.mkdir(parents=True, exist_ok=True)

zstd3_writer_options = {
    "compression": "zstd",
    "compression_level": 3,
    "use_dictionary": ["record_id", "task_name"],
    "write_statistics": True,
    "data_page_size": 1_048_576,
    "write_page_index": True,
    "sorting_columns": sorting_columns,
}
ds.to_parquet(str(output_path / "spectrograms.parquet"), **zstd3_writer_options)

Result: zstd compresses better than snappy.

What about transposing the spectrogram, so that columnar compression is applied to each frequency independently?

In [None]:
# option 10: transpose the spectrograms for better column compression efficiency maybe?
# This wraps option 7's generator instead of re-defining
# `audio_feature_generator_sorted`, which used to silently replace option 7's
# version for any later re-run of that cell.
def audio_feature_generator_transposed(
    audio_paths,
) -> t.Generator[t.Dict[str, t.Any], None, None]:
    """Yield option 7's records with each spectrogram transposed.

    Ordering and the ``record_id``/``task_name`` columns come from
    ``audio_feature_generator_sorted``; only ``spectrogram`` is replaced by its
    transpose (``.T``).
    """
    for record in audio_feature_generator_sorted(audio_paths):
        record['spectrogram'] = record['spectrogram'].T
        yield record

audio_feature_generator_transposed_partial = partial(
    audio_feature_generator_transposed,
    audio_paths=audio_paths,
)

# Rows are already ordered by record_id, then task_name, by the generator.
ds = Dataset.from_generator(audio_feature_generator_transposed_partial, num_proc=1)

output_path = output_base_path.joinpath('option_10')
output_path.mkdir(parents=True, exist_ok=True)

ds.to_parquet(
    str(output_path / "spectrograms.parquet"),
    compression="zstd",  # Better compression ratio than snappy, still good speed
    compression_level=9,
    use_dictionary=["record_id", "task_name"],  # Enable dictionary encoding for strings
    write_statistics=True,
    data_page_size=1_048_576,  # 1MB pages
    write_page_index=True,  # Enable page index for better filtering
    sorting_columns=(
        SortingColumn(column_index=0, descending=False),
        SortingColumn(column_index=1, descending=False),
    ),
)

Nope! It compresses better in the original orientation, where the individual time samples are the columns.

OK, back to option 7, but let's try converting the spectrogram to a decibel (power) scale, and then add `session_id`.

In [None]:
def audio_feature_power_no_session(
    audio_paths,
) -> t.Generator[t.Dict[str, t.Any], None, None]:
    """Yield dB-scaled spectrogram records with participant and task ids (no session).

    Named distinctly from ``audio_feature_power``. The next cell defines a
    session-aware version under that name, and the full-dataset run depends on
    it, so re-running this cell must not overwrite it.

    Paths are ordered by the ``sub-`` field, then the ``task-`` field of the
    file stem. Each sibling ``<stem>_features.pt`` spectrogram is converted
    with librosa's ``amplitude_to_db``.
    """
    audio_paths = sorted(
        audio_paths,
        # sort first by subject, then by task
        key=lambda x: (x.stem.split('_')[0], x.stem.split('_')[2])
    )

    for wav_path in tqdm(audio_paths, total=len(audio_paths), desc="Extracting features"):
        pt_file = wav_path.parent / f"{wav_path.stem}_features.pt"
        features = torch.load(pt_file, weights_only=False)

        # NOTE(review): amplitude_to_db squares its input. If the stored
        # spectrogram is already a power spectrogram (the Griffin-Lim cell
        # below inverts with power=2), power_to_db would be the matching
        # conversion — confirm before changing, since it alters the saved values.
        spectrogram = amplitude_to_db(features['torchaudio']['spectrogram'].numpy().astype(np.float32))
        yield {
            'participant_id': wav_path.stem.split('_')[0][4:],  # skip "sub-" prefix
            'task_name': wav_path.stem.split('_')[2][5:],  # skip "task-" prefix
            'spectrogram': spectrogram,
        }

audio_feature_power_no_session_partial = partial(
    audio_feature_power_no_session,
    audio_paths=audio_paths,
)

# Rows are already ordered by participant_id, then task_name, by the generator.
ds = Dataset.from_generator(
    audio_feature_power_no_session_partial, num_proc=1
)

output_path = output_base_path.joinpath('option_11')
output_path.mkdir(parents=True, exist_ok=True)

ds.to_parquet(
    str(output_path / "spectrograms.parquet"),
    compression="zstd",  # Better compression ratio than snappy, still good speed
    compression_level=3,
    use_dictionary=["participant_id", "task_name"],  # Enable dictionary encoding for strings
    write_statistics=True,
    data_page_size=1_048_576,  # 1MB pages
    write_page_index=True,  # Enable page index for better filtering
    sorting_columns=(
        SortingColumn(column_index=0, descending=False),
        SortingColumn(column_index=1, descending=False),
    ),
)

That worked well; next, we add `session_id`.

In [None]:
def audio_feature_power(
    audio_paths,
) -> t.Generator[t.Dict[str, t.Any], None, None]:
    """Yield dB-scaled spectrogram records with participant, session and task ids.

    File stems are split on ``_``: field 0 is ``sub-<id>``, field 1 is
    ``ses-<id>``, and field 2 is ``task-<name>``. Paths are ordered by field 0,
    then field 2. For each path, the sibling ``<stem>_features.pt`` is loaded and
    its ``['torchaudio']['spectrogram']`` is converted with librosa's
    ``amplitude_to_db``. Records are shaped for ``Dataset.from_generator``.
    """
    def by_subject_and_task(path):
        fields = path.stem.split('_')
        return fields[0], fields[2]

    ordered_paths = sorted(audio_paths, key=by_subject_and_task)

    for wav_path in tqdm(ordered_paths, total=len(ordered_paths), desc="Extracting features"):
        features = torch.load(wav_path.parent / f"{wav_path.stem}_features.pt", weights_only=False)
        subject_field, session_field, task_field = wav_path.stem.split('_')[:3]
        # NOTE(review): amplitude_to_db squares its input; if the stored
        # spectrogram is already power, power_to_db would match — confirm.
        power_db = amplitude_to_db(features['torchaudio']['spectrogram'].numpy().astype(np.float32))
        yield {
            'participant_id': subject_field[4:],  # drop "sub-"
            'session_id': session_field[4:],  # drop "ses-"
            'task_name': task_field[5:],  # drop "task-"
            'spectrogram': power_db,
        }

audio_feature_power_partial = partial(audio_feature_power, audio_paths=audio_paths)

# The generator already yields rows ordered by participant_id, then task_name.
ds = Dataset.from_generator(audio_feature_power_partial, num_proc=1)

output_path = output_base_path / 'option_12'
output_path.mkdir(parents=True, exist_ok=True)

# Generator columns are (participant_id, session_id, task_name, spectrogram),
# so the declared sort order is column 0 then column 2.
ds.to_parquet(
    str(output_path / "spectrograms.parquet"),
    compression="zstd",  # better ratio than snappy at similar speed
    compression_level=3,
    use_dictionary=["participant_id", "session_id", "task_name"],  # dictionary-encode the id strings
    write_statistics=True,
    data_page_size=1_048_576,  # 1 MiB data pages
    write_page_index=True,  # page index enables finer-grained filtering
    sorting_columns=tuple(
        SortingColumn(column_index=index, descending=False) for index in (0, 2)
    ),
)

That compresses better. Next, try quantizing the values to a fixed number of decimal places, to see whether storing them as fixed-point integers compresses better.

In [None]:
def audio_feature_power_decimal(audio_paths) -> t.Generator[t.Dict[str, t.Any], None, None]:
    audio_paths = sorted(audio_paths, key=lambda x: (x.stem.split('_')[0], x.stem.split('_')[2]))
    
    for wav_path in tqdm(audio_paths, total=len(audio_paths), desc="Extracting features"):
        output = {}
        pt_file = wav_path.parent / f"{wav_path.stem}_features.pt"
        features = torch.load(pt_file, weights_only=False)
        
        output['participant_id'] = wav_path.stem.split('_')[0][4:]
        output['session_id'] = wav_path.stem.split('_')[1][4:]
        output['task_name'] = wav_path.stem.split('_')[2][5:]
        
        # Process spectrogram
        spec = amplitude_to_db(features['torchaudio']['spectrogram'].numpy().astype(np.float32))
        n_row, n_col = spec.shape
        # Round to 4 decimal places
        spec_rounded = np.round(spec, decimals=4)
        # cast to decimal type
        spec_decimal = pa.array(spec_rounded.flatten(), type=pa.decimal128(8, 4))
        # Reshape back to original dimensions
        output['spectrogram'] = spec_decimal.reshape(n_row, n_col)
        
        yield output

option_13_generator = partial(audio_feature_power_decimal, audio_paths=audio_paths)
ds = Dataset.from_generator(option_13_generator, num_proc=1)

output_path = output_base_path / 'option_13'
output_path.mkdir(parents=True, exist_ok=True)
ds.to_parquet(
    str(output_path / "spectrograms.parquet"),
    compression="zstd",
    compression_level=3,
    use_dictionary=["participant_id", "session_id", "task_name"],
    write_statistics=True,
    data_page_size=1_048_576,
    write_page_index=True,
    sorting_columns=(
        SortingColumn(column_index=0, descending=False),  # participant_id
        SortingColumn(column_index=2, descending=False),  # task_name
    ),
)

## Try with the full dataset

Settings (as in option 12):
- Power spectrum in dB
- No decimal packing

In [None]:
full_dataset_audio_paths = get_audio_paths(Path.home().joinpath('data', 'bridge2ai', 'release-1.0'))
# Keep every recording (no debug subsetting here); only the paths are needed.
full_dataset_audio_paths = [Path(d['path']) for d in full_dataset_audio_paths]
print(f'Found {len(full_dataset_audio_paths)} audio files.')

# Same settings as option 12. audio_feature_power orders rows by
# participant_id, then task_name, and emits session_id as column 1.
ds = Dataset.from_generator(
    partial(audio_feature_power, audio_paths=full_dataset_audio_paths),
    num_proc=1,
)

output_path = output_base_path.joinpath('option_12_full')
output_path.mkdir(parents=True, exist_ok=True)

ds.to_parquet(
    str(output_path / "spectrograms.parquet"),
    compression="zstd",  # Better compression ratio than snappy, still good speed
    compression_level=3,
    use_dictionary=["participant_id", "session_id", "task_name"],  # Enable dictionary encoding for strings
    write_statistics=True,
    data_page_size=1_048_576,  # 1MB pages
    write_page_index=True,  # Enable page index for better filtering
    sorting_columns=(
        SortingColumn(column_index=0, descending=False),  # participant_id
        SortingColumn(column_index=2, descending=False),  # task_name
    ),
)

## Practice loading the dataset

In [None]:
dataset = Dataset.from_parquet(
    output_base_path.joinpath('option_7', 'spectrograms.parquet').as_posix(),
)
print(dataset)

# plot a single spectrogram
idx = 18
assert idx < len(dataset), f'idx={idx} is out of range for {len(dataset)} rows'
row = dataset[idx]  # decode the row once instead of once per field
print(row['record_id'])
print(row['task_name'])
spectrogram = np.asarray(row['spectrogram'])

# option 7 stores the raw spectrogram, so convert to decibels for display
spectrogram = amplitude_to_db(spectrogram)
n_frames = spectrogram.shape[1]
fig, ax = plt.subplots(figsize=(18, 5))
image = ax.imshow(spectrogram, aspect='auto', origin='lower')
# x-ticks in seconds, assuming one frame per 10 ms hop. Labels are rounded
# rather than truncated, so e.g. 4.5 s is no longer shown as "4".
ax.set_xticks(np.linspace(0, n_frames, 11))
ax.set_xticklabels([f'{sec:.1f}' for sec in np.linspace(0, n_frames / 100, 11)])
ax.set_xlabel('Time (s)')
ax.set_title(f"{row['record_id']} / {row['task_name']}")
fig.colorbar(image, ax=ax)
plt.show()

In [None]:
dataset = Dataset.from_parquet(
    output_base_path.joinpath('option_12', 'spectrograms.parquet').as_posix(),
)
print(dataset)

idx = 18
assert idx < len(dataset), f'idx={idx} is out of range for {len(dataset)} rows'
row = dataset[idx]  # decode the row once instead of once per field
# option 12 names the subject column participant_id (there is no record_id column)
print(row['participant_id'])
print(row['session_id'])
print(row['task_name'])
# option 12 already stores values in dB, so no conversion is needed
spectrogram = np.asarray(row['spectrogram'])

n_frames = spectrogram.shape[1]
fig, ax = plt.subplots(figsize=(18, 5))
image = ax.imshow(spectrogram, aspect='auto', origin='lower')
# x-ticks in seconds, assuming one frame per 10 ms hop; round rather than truncate
ax.set_xticks(np.linspace(0, n_frames, 11))
ax.set_xticklabels([f'{sec:.1f}' for sec in np.linspace(0, n_frames / 100, 11)])
ax.set_xlabel('Time (s)')
ax.set_title(f"{row['participant_id']} / {row['task_name']}")
fig.colorbar(image, ax=ax)
plt.show()

In [None]:
import IPython.display as Ipd
import torchaudio
from librosa import db_to_amplitude
# Reconstruct an audible waveform from the loaded option 12 spectrogram with Griffin-Lim.
# The stored values came from amplitude_to_db, so invert with db_to_amplitude.
# HF returns nested Python floats; use float32 to match GriffinLim's float32 window.
spectrogram = db_to_amplitude(np.asarray(dataset[idx]['spectrogram'], dtype=np.float32))
n_fft = 2 * (spectrogram.shape[0] - 1)  # n_freq = n_fft // 2 + 1
sr = 16000  # assumed sample rate of the recordings — TODO confirm against the feature extractor
win_length = int(sr * 25 / 1000)  # 25 ms window
hop_length = int(sr * 10 / 1000)  # 10 ms hop
griffin_lim = torchaudio.transforms.GriffinLim(n_fft=n_fft, win_length=win_length, hop_length=hop_length, power=2)
reconstructed_waveform = griffin_lim(torch.tensor(spectrogram, dtype=torch.float32))
Ipd.display(Ipd.Audio(data=reconstructed_waveform.numpy(), rate=sr))