In [2]:
from studies.study_factory import StudyFactory
from dataloader.dataloader import DataLoader

# Build the loader used by the fetching cell further down.
# The grouping comments below are based on parameter names only; the exact
# meaning of each option is defined by DataLoader — TODO confirm against it.
dataloader = DataLoader(
    # -- buffering / on-disk cache --
    buffer_size=10,
    max_cache_size_gb=100,
    cache_dir="cache",
    # -- signal preprocessing --
    notch_filter=True,
    frequency_bands=dict(all=(0.5, 100)),  # a single band, keyed "all"
    scaling="both",
    brain_clipping=20,
    baseline_window=0.5,
    new_freq=100,  # presumably the resampling rate in Hz — TODO confirm
    # -- batch composition --
    batch_types=dict(audio=3),
    batch_kwargs=dict(
        audio=dict(
            max_random_shift=1,
            window_size=4,
            window_stride=1,
            audio_sample_rate=16000,
            hop_length=160,
            audio_processor="openai/whisper-large-v3",
        ),
    ),
)

2024-12-11 21:41:53,721	INFO worker.py:1821 -- Started a local Ray instance.


In [2]:
from models.simpleconv import SimpleConv
from config.simpleconv_config import SimpleConvConfig
import torch

# Smoke test: build a default-configured SimpleConv, push one random batch
# through it and print the shape of the result.
#
# NOTE(review): the output saved with this cell is a RuntimeError
# ("tensor a (512) must match ... tensor b (256) at non-singleton dimension 1").
# None of the sizes used below is 512 or 256, so the mismatch presumably
# originates inside SimpleConv / SimpleConvConfig — TODO confirm there.

# Sizes of the synthetic batch, named so the tensors below stay consistent.
BATCH_SIZE = 3
N_CHANNELS = 208  # shared by inputs (dim 1) and layout (dim 0)
N_TIMEPOINTS = 400  # presumably a 4 s window at 100 Hz — TODO confirm

# Seed before constructing the model so that both the initial weights and the
# random inputs are identical on every run of this cell.
torch.manual_seed(0)

config = SimpleConvConfig()
model = SimpleConv(config)

inputs = torch.randn(BATCH_SIZE, N_CHANNELS, N_TIMEPOINTS)
layout = torch.randn(N_CHANNELS, 2)
subjects = torch.arange(BATCH_SIZE)  # tensor([0, 1, 2]), one index per batch row

# Only the output shape is inspected, so skip building the autograd graph.
with torch.no_grad():
    outputs = model(inputs, layout, subjects)
print(outputs.shape)


SimpleConv: 
	Params: 10496384
	Conv blocks: 5
	Trans layers: 0


RuntimeError: The size of tensor a (512) must match the size of tensor b (256) at non-singleton dimension 1

In [3]:
# Create the study object; its `recordings` attribute is flattened and handed
# to the dataloader in the cells that follow.
study = StudyFactory.create_study(
    study_name="gwilliams2023",
    batch_type="audio",
    path="data/gwilliams2023",
    cache_enabled=True,
    max_cache_size=200,  # counted in items
)

Loading Gwilliams2023 with batch type audio


In [5]:
recordings = []

from itertools import chain
import random

# study.recordings is a list nested three levels deep; collapse it into a
# single flat list. The first level is removed by chain.from_iterable, the
# second by the inner loop of the comprehension.
flat_recordings = [
    recording
    for group in chain.from_iterable(study.recordings)
    for recording in group
]
# random.shuffle(flat_recordings)

In [6]:
# Start background fetching, then consume recordings until
# dataloader.get_recording() returns None.
import time
import traceback


dataloader.start_fetching(flat_recordings, cache=True)

# NOTE(review): dataloader.stop() is only called on the interrupt/error paths;
# confirm whether the loader also needs an explicit stop after a clean run.
try:
    batches, recs, start_time = 0, 0, time.time()
    print(f'Total recordings: {len(flat_recordings)}')

    while True:
        batch = dataloader.get_recording()

        if batch is None:
            break

        brain = batch.brain_segments['all']

        # Size of the first axis of this recording's "all" band.
        print(f"Batch shape: {brain.shape[0]}")
        batches += brain.shape[0]
        recs += 1

except KeyboardInterrupt:
    print("Interrupted")
    dataloader.stop()
except Exception as e:
    print("Error", e)
    # Keep the notebook running, but show where the failure happened instead
    # of reducing it to a one-line message.
    traceback.print_exc()
    dataloader.stop()
else:
    # Reached only when the loop ended via the None sentinel.
    elapsed = time.time() - start_time
    print(f"Processed {recs} recordings ({batches} batch items) in {elapsed:.2f}s")

Total recordings: 196
Batch shape: 198
Batch shape: 425
Batch shape: 201
Batch shape: 735
Batch shape: 421
Batch shape: 199
Batch shape: 1083
Batch shape: 1068
Batch shape: 728
Batch shape: 1081
Batch shape: 423
Batch shape: 720
Batch shape: 196
Batch shape: 421
Batch shape: 731
Batch shape: 195
Batch shape: 421
Batch shape: 201
Batch shape: 1076
Batch shape: 1086
Batch shape: 726
Batch shape: 1084
Batch shape: 428
Batch shape: 733
Batch shape: 206
Batch shape: 424
Batch shape: 727
Batch shape: 203
Batch shape: 425
Batch shape: 203
Batch shape: 1087
Batch shape: 1083
Batch shape: 731
Batch shape: 1075
Batch shape: 418
Batch shape: 739
Batch shape: 200
Batch shape: 426
Batch shape: 730
Batch shape: 202
Batch shape: 417
Batch shape: 207
Batch shape: 1077
Batch shape: 1084
Batch shape: 727
Batch shape: 1076
Batch shape: 416
Batch shape: 720
Batch shape: 199
Batch shape: 424
Batch shape: 731
Batch shape: 203
Batch shape: 422
Batch shape: 195
Batch shape: 1082
Batch shape: 1085
Batch shape:

[36m(raylet)[0m Spilled 2950 MiB, 8 objects, write throughput 2177 MiB/s. Set RAY_verbose_spill_logs=0 to disable this message.
