In [1]:
from studies.study_factory import StudyFactory
from dataloader.dataloader import DataLoader

dataloader = DataLoader(
    buffer_size=10,
    max_cache_size_gb=100,
    cache_dir="cache",
    notch_filter=True,
    frequency_bands={"all": (0.5, 100)},
    scaling="both",
    brain_clipping=20,
    baseline_window=0.5,
    new_freq=100,
    batch_types={"audio": 12},
    batch_kwargs={
        'audio': {
            'max_random_shift': 1,
            'window_size': 4,
            'window_stride': 1, 
            'audio_sample_rate': 16000,
            'hop_length': 160,
            'audio_processor': "openai/whisper-large-v3"
        }
    },
)

2024-12-10 13:14:42,348	INFO worker.py:1821 -- Started a local Ray instance.


In [2]:
study = StudyFactory.create_study(
    study_name='gwilliams2023',
    batch_type='audio',
    path='data/gwilliams2023',
    cache_enabled=True,
    max_cache_size=200, # in items
)

Loading GWilliams2023 with batch type audio


In [3]:
recordings = []

# Unfold all recordings (3 dim) of python list to 1
from itertools import chain
import random

flat_recordings = list(chain.from_iterable(chain.from_iterable(study.recordings)))
# random.shuffle(flat_recordings)

In [4]:
# # Start background fetching
import time


dataloader.start_fetching(flat_recordings, cache=True)

# Process batches as they become available
try:
    batches, recs, start_time = 0, 0, time.time()
    print(f'Total recordings: {len(flat_recordings)}')
    
    while True:
        batch = dataloader.get_recording()
        
        if batch is None:
            break
        
        brain = batch.brain_segments['all']
        batches += brain.shape[0]
        recs += 1
        
        if recs % 10 == 0:
            print(f"Batch {batches} ({recs} recordings) processed in {time.time() - start_time:.2f}s")
            print(
                f"Average processing time per recording: {(time.time() - start_time) / recs:.2f}s"
            )
            print(
                f"Average processing time per batch: {(time.time() - start_time) / batches:.2f}s"
            )
        
except KeyboardInterrupt:
    print("Interrupted")
    dataloader.stop()
except Exception as e:
    print("Error", e)
    dataloader.stop()

Total recordings: 196
Batch 3407 (10 recordings) processed in 7.49s
Average processing time per recording: 0.75s
Average processing time per batch: 0.00s
Batch 9724 (20 recordings) processed in 10.53s
Average processing time per recording: 0.53s
Average processing time per batch: 0.00s
Batch 15226 (30 recordings) processed in 12.85s
Average processing time per recording: 0.43s
Average processing time per batch: 0.00s
Batch 21909 (40 recordings) processed in 16.67s
Average processing time per recording: 0.42s
Average processing time per batch: 0.00s
Batch 27389 (50 recordings) processed in 19.50s
Average processing time per recording: 0.39s
Average processing time per batch: 0.00s
Batch 34056 (60 recordings) processed in 22.05s
Average processing time per recording: 0.37s
Average processing time per batch: 0.00s
Batch 39201 (70 recordings) processed in 24.52s
Average processing time per recording: 0.35s
Average processing time per batch: 0.00s
Batch 45892 (80 recordings) processed in 28

[36m(raylet)[0m Spilled 2136 MiB, 5 objects, write throughput 1150 MiB/s. Set RAY_verbose_spill_logs=0 to disable this message.
[36m(raylet)[0m Spilled 4799 MiB, 11 objects, write throughput 1436 MiB/s.
[36m(raylet)[0m Spilled 8776 MiB, 21 objects, write throughput 1613 MiB/s.
