In [1]:
from studies.study_factory import StudyFactory
from dataloader.dataloader import DataLoader

dataloader = DataLoader(
    buffer_size=10,
    max_cache_size_gb=100,
    cache_dir="cache",
    notch_filter=True,
    frequency_bands={"all": (0.5, 100)},
    scaling="both",
    brain_clipping=20,
    baseline_window=0.5,
    new_freq=100,
    batch_types={"audio": 12},
    batch_kwargs={
        'audio': {
            'max_random_shift': 1,
            'window_size': 4,
            'window_stride': 1, 
            'audio_sample_rate': 16000,
            'hop_length': 160,
            'audio_processor': "openai/whisper-large-v3"
        }
    },
)

2024-12-12 03:49:39,746	INFO worker.py:1821 -- Started a local Ray instance.


In [2]:
study = StudyFactory.create_study(
    study_name='gwilliams2023',
    batch_type='audio',
    path='data/gwilliams2023',
    cache_enabled=True,
    max_cache_size=200, # in items
)

Loading Gwilliams2023 with batch type audio


In [3]:
from itertools import chain
import random

flat_recordings = list(chain.from_iterable(chain.from_iterable(study.recordings)))

In [4]:
# # Start background fetching
import time


dataloader.start_fetching(flat_recordings, cache=True)

# Process batches as they become available
try:
    batches, recs, start_time = 0, 0, time.time()
    print(f'Total recordings: {len(flat_recordings)}')
    
    while True:
        batch = dataloader.get_recording()
        
        if batch is None:
            break
        
        brain = batch.brain_segments['all']
        batches += brain.shape[0]
        recs += 1
        
        if recs % 10 == 0:
            print(f"Batch {batches} ({recs} recordings) processed in {time.time() - start_time:.2f}s")
            print(
                f"Average processing time per recording: {(time.time() - start_time) / recs:.2f}s"
            )
            print(
                f"Average processing time per batch: {(time.time() - start_time) / batches:.2f}s"
            )
        
except KeyboardInterrupt:
    print("Interrupted")
    dataloader.stop()
except Exception as e:
    print("Error", e)
    dataloader.stop()

Total recordings: 196
Batch 3732 (10 recordings) processed in 7.29s
Average processing time per recording: 0.73s
Average processing time per batch: 0.00s
Batch 10102 (20 recordings) processed in 17.08s
Average processing time per recording: 0.88s
Average processing time per batch: 0.00s
Batch 17724 (30 recordings) processed in 27.15s
Average processing time per recording: 0.91s
Average processing time per batch: 0.00s
Batch 23418 (40 recordings) processed in 37.23s
Average processing time per recording: 0.93s
Average processing time per batch: 0.00s
Batch 30068 (50 recordings) processed in 46.93s
Average processing time per recording: 0.94s
Average processing time per batch: 0.00s
Batch 35038 (60 recordings) processed in 49.15s
Average processing time per recording: 0.82s
Average processing time per batch: 0.00s
Batch 39827 (70 recordings) processed in 58.14s
Average processing time per recording: 0.83s
Average processing time per batch: 0.00s
Batch 48620 (80 recordings) processed in 7

[36m(raylet)[0m Spilled 2607 MiB, 6 objects, write throughput 1056 MiB/s. Set RAY_verbose_spill_logs=0 to disable this message.
[36m(raylet)[0m Spilled 4355 MiB, 9 objects, write throughput 1132 MiB/s.
[36m(raylet)[0m Spilled 16605 MiB, 42 objects, write throughput 1800 MiB/s.[32m [repeated 2x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)[0m
[36m(raylet)[0m Spilled 33088 MiB, 85 objects, write throughput 2009 MiB/s.
