In [None]:
from studies.study_factory import StudyFactory
from dataloader.dataloader import DataLoader

dataloader = DataLoader(
    buffer_size=10,
    max_cache_size_gb=100,
    cache_dir="cache",
    notch_filter=True,
    frequency_bands={"all": (0.5, 100)},
    scaling="both",
    brain_clipping=20,
    baseline_window=0.5,
    new_freq=100,
    batch_types={"audio": 12},
    batch_kwargs={
        'audio': {
            'max_random_shift': 1,
            'window_size': 4,
            'window_stride': 1, 
            'audio_sample_rate': 16000,
            'hop_length': 160,
            'audio_processor': "openai/whisper-large-v3"
        }
    },
)

In [12]:
study = StudyFactory.create_study(
    study_name='gwilliams2023',
    batch_type='audio',
    path='data/gwilliams2023',
    cache_enabled=True,
    max_cache_size=200, # in items
)

Loading Gwilliams2023 with batch type audio


In [None]:
from itertools import chain
import random

flat_recordings = list(chain.from_iterable(chain.from_iterable(study.recordings)))

In [None]:
# # Start background fetching
import time


dataloader.start_fetching([flat_recordings[0]], cache=True)

# Process batches as they become available
try:
    batches, recs, start_time = 0, 0, time.time()
    print(f'Total recordings: {len(flat_recordings)}')
    
    while True:
        batch = dataloader.get_recording()
        
        if batch is None:
            break
        
        brain = batch.brain_segments['all']
        batches += brain.shape[0]
        recs += 1
        
        if recs % 10 == 0:
            print(f"Batch {batches} ({recs} recordings) processed in {time.time() - start_time:.2f}s")
            print(
                f"Average processing time per recording: {(time.time() - start_time) / recs:.2f}s"
            )
            print(
                f"Average processing time per batch: {(time.time() - start_time) / batches:.2f}s"
            )
        
except KeyboardInterrupt:
    print("Interrupted")
    dataloader.stop()
except Exception as e:
    print("Error", e)
    dataloader.stop()

Total recordings: 196
Error invalid index to scalar variable.
