In [22]:
import logging
from dataset.shard_loader import ShardDataLoader
from torch.utils.data import DataLoader

def main(
    dataset_path='/home/gleb/Synaptech/data_magic/data/openfmri/',
    window_size=275,
    sample_length=275,
    batch_size=32,
    max_batches=None,
):
    """Build one epoch's dataset from the training shards and log batch shapes.

    Args:
        dataset_path: Dataset root directory passed to ``ShardDataLoader``.
        window_size: ``window_size`` argument forwarded to ``ShardDataLoader``.
        sample_length: ``sample_length`` argument forwarded to
            ``prepare_epoch_dataset``.
        batch_size: Number of samples per batch produced by the ``DataLoader``.
        max_batches: Stop after logging this many batches. ``None`` (the
            default) walks the whole epoch.
    """
    # Local import keeps this cell self-contained
    from itertools import islice

    # Set up logging
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # Create an instance of the ShardDataLoader
    shard_loader = ShardDataLoader(
        dataset_path=dataset_path,
        mode='train',
        window_size=window_size,
        verbose=True
    )

    # Prepare the dataset for an epoch
    epoch_dataset = shard_loader.prepare_epoch_dataset(sample_length=sample_length)

    # Create a DataLoader
    data_loader = DataLoader(
        epoch_dataset,
        batch_size=batch_size,
        shuffle=True  # Forwarded to DataLoader
    )

    # islice(..., None) yields every batch; an integer cap stops early without
    # pulling an extra batch from the loader.
    for batch_idx, (eeg_batch, mag_batch) in enumerate(islice(data_loader, max_batches)):
        logger.info(f"Batch {batch_idx + 1}:")
        logger.info(f"  EEG batch shape: {eeg_batch.shape}")
        logger.info(f"  MAG batch shape: {mag_batch.shape}")

# Entry point. The "INFO:__main__:" lines in the recorded output show that
# __name__ is "__main__" in this kernel, so executing the cell runs main().
if __name__ == "__main__":
    main()

INFO:dataset.shard_loader:Loaded 65 shard pairs for mode 'train'
2024-12-26 01:27:57,674 [INFO] Batch 1:
INFO:__main__:Batch 1:
2024-12-26 01:27:57,675 [INFO]   EEG batch shape: torch.Size([32, 74, 275])
INFO:__main__:  EEG batch shape: torch.Size([32, 74, 275])
2024-12-26 01:27:57,675 [INFO]   MAG batch shape: torch.Size([32, 102, 275])
INFO:__main__:  MAG batch shape: torch.Size([32, 102, 275])
2024-12-26 01:27:57,691 [INFO] Batch 2:
INFO:__main__:Batch 2:
2024-12-26 01:27:57,692 [INFO]   EEG batch shape: torch.Size([32, 74, 275])
INFO:__main__:  EEG batch shape: torch.Size([32, 74, 275])
2024-12-26 01:27:57,692 [INFO]   MAG batch shape: torch.Size([32, 102, 275])
INFO:__main__:  MAG batch shape: torch.Size([32, 102, 275])
2024-12-26 01:27:57,693 [INFO] Batch 3:
INFO:__main__:Batch 3:
2024-12-26 01:27:57,693 [INFO]   EEG batch shape: torch.Size([32, 74, 275])
INFO:__main__:  EEG batch shape: torch.Size([32, 74, 275])
2024-12-26 01:27:57,694 [INFO]   MAG batch shape: torch.Size([32, 1

### Cleanup Shards
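- Removes the `EEG_shards` and `MAG_shards` folders from every subject directory under `train`, `val`, and `test`. This is destructive: the folders are deleted with `shutil.rmtree`.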

In [None]:
import os
import shutil
from pathlib import Path
import logging

# Set up logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
if not logger.handlers:
    # Guarded so that re-running the cell does not stack a second handler.
    handler = logging.StreamHandler()
    formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)
# This logger has its own handler. If records also propagated to the root
# logger, every message would be printed twice as soon as the root logger has
# a handler too (e.g. after logging.basicConfig).
logger.propagate = False

def cleanup_shard_folders(
    dataset_path="/home/gleb/Synaptech/data_magic/data/openfmri",
    dry_run=False,
):
    """
    Remove all 'EEG_shards' and 'MAG_shards' folders from the dataset.

    Looks in every subject directory under ``<dataset_path>/train``,
    ``<dataset_path>/val`` and ``<dataset_path>/test``. Missing mode folders
    are skipped with a warning.

    Args:
        dataset_path: Root directory of the dataset.
        dry_run: When True, only log which folders would be removed and leave
            the filesystem untouched.

    Returns:
        int: Number of shard folders removed (or, in a dry run, the number
        that would be removed).
    """
    logger.info(f"Starting cleanup of shard folders in: {dataset_path}")

    shard_dir_names = ("EEG_shards", "MAG_shards")
    folders_removed = 0

    # Walk through train/val/test folders
    for mode in ("train", "val", "test"):
        mode_path = Path(dataset_path) / mode
        if not mode_path.is_dir():
            logger.warning(f"Skipping non-existent folder: {mode_path}")
            continue

        # For each subject folder
        for subject_path in mode_path.iterdir():
            if not subject_path.is_dir():
                continue

            for shard_dir_name in shard_dir_names:
                shard_path = subject_path / shard_dir_name
                is_link = shard_path.is_symlink()
                if not is_link and not shard_path.exists():
                    continue
                # shutil.rmtree raises on plain files and on symlinks, so only
                # real directories are eligible for removal.
                if is_link or not shard_path.is_dir():
                    logger.warning(f"Skipping {shard_path}: not a real directory")
                    continue

                if dry_run:
                    logger.info(f"Would remove: {shard_path}")
                else:
                    logger.info(f"Removing: {shard_path}")
                    shutil.rmtree(shard_path)
                folders_removed += 1

    if dry_run:
        logger.info(f"Dry run completed! Would remove {folders_removed} shard folders.")
    else:
        logger.info(f"Cleanup completed! Removed {folders_removed} shard folders.")
    return folders_removed

# NOTE(review): when __name__ is "__main__" (script run, and presumably a
# notebook kernel as well), executing this cell immediately deletes the shard
# folders under the default dataset_path.
if __name__ == "__main__":
    cleanup_shard_folders()