In [1]:
from torch.utils.data import Dataset
import os
import logging
import traceback
import shutil
import zipfile
import random
import re
import numpy as np
import contextlib
import mne

In [2]:
# Setting up the logger: configure logging at INFO level so the progress
# messages emitted through `logger.info` in the cells below are shown.
logging.basicConfig(level=logging.INFO)
# Module-level logger shared by the helper functions defined in this notebook.
logger = logging.getLogger(__name__)

In [3]:
def move_contents_to_parent_and_delete_sub(subfolder, parentfolder, expected_content_count=None):
    """
    Moves all contents from a subfolder to its parent folder and deletes the subfolder.

    Every (now empty) directory between the parent folder and the subfolder is removed
    as well. Paths are normalized first, so a parent folder given with a trailing
    separator or as a relative path such as '.' is handled correctly.

    Parameters:
        subfolder (str): Path of the subfolder, relative to ``parentfolder``. It may be
            nested several levels deep (e.g. 'a/b/c').
        parentfolder (str): Path to the parent folder where contents will be moved.
        expected_content_count (int, optional): Expected number of items in the subfolder.

    Returns:
        None

    Raises:
        ValueError: If ``subfolder`` does not resolve to a directory strictly inside
            ``parentfolder``.
        AssertionError: If ``expected_content_count`` is given and does not match.
        OSError: If an intermediate directory is not empty and cannot be removed, or
            if the subfolder does not exist.
    """
    # Absolute, normalized paths make the "have we reached the parent?" comparison
    # below reliable regardless of trailing separators or relative spellings.
    parent_path = os.path.abspath(parentfolder)
    subfolder_path = os.path.abspath(os.path.join(parent_path, subfolder))

    # Refuse to operate outside the parent: otherwise the cleanup loop would walk
    # up the directory tree without ever meeting its stop condition.
    relative = os.path.relpath(subfolder_path, parent_path)
    if relative in (os.curdir, os.pardir) or relative.startswith(os.pardir + os.sep):
        raise ValueError(
            f"Subfolder {subfolder!r} must be located strictly inside {parentfolder!r}"
        )

    items = os.listdir(subfolder_path)
    if expected_content_count is not None:
        actual_content_count = len(items)
        assert actual_content_count == expected_content_count, (
            f"Expected {expected_content_count} items, but found {actual_content_count} in {subfolder_path}"
        )

    # Move contents of subfolder to parentfolder
    for item in items:
        shutil.move(os.path.join(subfolder_path, item), parent_path)

    # Delete all folders between parentfolder and subfolder (deepest first).
    # os.rmdir only removes empty directories, so unrelated content is never lost.
    current_path = subfolder_path
    while current_path != parent_path:
        next_path = os.path.dirname(current_path)
        os.rmdir(current_path)
        current_path = next_path

In [4]:
def unzip_and_rename_in_folder(folder, remove=False):
    """
    Unzips all zip files in the specified folder and renames the extracted folders.

    Extracted top-level entries are renamed to ``folder_0``, ``folder_1``, ... in
    sorted name order so the numbering is reproducible. A ``.unzipped`` marker file
    is written only when every archive was processed without error; if the marker
    already exists the function returns immediately.

    Parameters:
        folder (str): The path to the folder containing zip files to be unzipped.
        remove (bool): If True, the zip files will be deleted after extraction.

    Returns:
        None

    Raises:
        AssertionError: If the folder contains anything other than zip files and
            dot-files before extraction starts.
    """
    unzipped_marker = os.path.join(folder, '.unzipped')
    if os.path.exists(unzipped_marker):
        logger.info(f"Folder {folder} is already unzipped. Exiting early.")
        return

    entries = os.listdir(folder)
    assert all(f.endswith('.zip') or f.startswith('.') for f in entries), (
        f"Not all files in {folder} are zip files or ignored files. Please delete non-zip files and re-run."
    )
    zip_files = sorted(f for f in entries if f.endswith('.zip'))
    logger.info(f"Unzipping and renaming {len(zip_files)} files in folder: {folder}")

    failed_archives = []
    for zip_file in zip_files:
        zip_file_path = os.path.join(folder, zip_file)
        logger.info(f"Unzipping {zip_file_path}...")
        try:
            with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
                zip_ref.extractall(folder)
            if remove:
                os.remove(zip_file_path)  # Remove the zip file after extraction
                logger.info(f"Unzipped and removed {zip_file_path}")
            else:
                logger.info(f"Unzipped {zip_file_path}")
        except Exception as e:
            # Best effort: keep going with the remaining archives, but remember the
            # failure so the folder is not marked as completely unzipped below.
            logger.error(f"Error unzipping {zip_file_path}, {e}")
            traceback.print_exc()
            failed_archives.append(zip_file_path)

    # Sort so that the folder_<i> numbering does not depend on os.listdir order.
    extracted_entries = sorted(
        f for f in os.listdir(folder) if not f.endswith('.zip') and not f.startswith('.')
    )
    for i, unzipped_folder in enumerate(extracted_entries):
        from_name = os.path.join(folder, unzipped_folder)
        to_name = os.path.join(folder, f'folder_{i}')
        os.rename(from_name, to_name)

    if failed_archives:
        # Without this guard a later run would exit early on an incomplete folder.
        logger.error(
            f"{len(failed_archives)} archive(s) could not be processed; "
            f"not marking {folder} as unzipped: {failed_archives}"
        )
        return

    # Create the .unzipped marker file
    with open(unzipped_marker, 'w') as marker_file:
        marker_file.write('')

In [5]:
def arrange_folders(dataset_folder):
    """
    Rearranges the files in the dataset folder after downloading and unzipping.

    This function organizes the dataset by moving participant folders into the main dataset folder,
    ensuring that relevant files are not nested within subdirectories. Steps:

    1. For every unzipped folder, move the contents of ``derivatives/meg_derivatives``
       up into the dataset folder and delete the emptied folder chain.
    2. For every participant folder, move the contents of ``ses-meg/meg`` up into the
       participant folder.
    3. Rename files whose name contains ``run-01`` .. ``run-06`` to ``run_<NN><ext>``.
    4. Verify the folder holds exactly the processed participant folders and write
       the ``.arranged`` marker. If the marker already exists, nothing is done.

    Hidden directories (names starting with '.') and loose files are ignored when
    selecting the folders to process.

    Parameters:
        dataset_folder (str): The path to the main dataset folder.

    Returns:
        None

    Raises:
        AssertionError: If, after arranging, the dataset folder contains entries
            (other than dot-files and zip files) besides the participant folders.
        OSError: If an expected nested folder is missing or cannot be removed.
    """
    arranged_marker = os.path.join(dataset_folder, '.arranged')
    if os.path.exists(arranged_marker):
        logger.info(f"Dataset folder {dataset_folder} is already arranged. Exiting early.")
        return

    def visible_directories():
        # Same filter for both passes: real directories only, no hidden folders
        # such as '.ipynb_checkpoints', which lack the expected nested layout.
        return sorted(
            f for f in os.listdir(dataset_folder)
            if not f.startswith('.') and os.path.isdir(os.path.join(dataset_folder, f))
        )

    for folder in visible_directories():
        # Move participant folders into dataset folder (so they are not nested)
        desired_subfolder = os.path.join(folder, 'derivatives', 'meg_derivatives')
        move_contents_to_parent_and_delete_sub(desired_subfolder, dataset_folder)

    # Go through participant folders and move nested relevant files up
    participant_folders = visible_directories()
    participant_count = len(participant_folders)
    desired_subfolder = os.path.join('ses-meg', 'meg')
    for participant_folder in participant_folders:
        participant_folder_path = os.path.join(dataset_folder, participant_folder)
        move_contents_to_parent_and_delete_sub(desired_subfolder, participant_folder_path)

    # Rename fif files to standardized format
    for participant_folder in participant_folders:
        participant_folder_path = os.path.join(dataset_folder, participant_folder)
        for file in os.listdir(participant_folder_path):
            match = re.match(r'.*run-(0[1-6]).*', file)
            if match:
                # Keep the original extension, replace the rest by the run number.
                new_file_name = f'run_{match.group(1)}' + os.path.splitext(file)[1]
                os.rename(
                    os.path.join(participant_folder_path, file),
                    os.path.join(participant_folder_path, new_file_name)
                )

    # Ensure the dataset folder contains the expected number of participant folders.
    # Loose files were skipped above, so they surface here as a count mismatch.
    non_zip_non_dot_folders = [
        f for f in os.listdir(dataset_folder) if not f.startswith('.') and not f.endswith('.zip')
    ]
    assert len(non_zip_non_dot_folders) == participant_count, (
        f"ERROR: Dataset folder contains {len(non_zip_non_dot_folders)} folders, but "
        f"{participant_count} (participant) folders are expected."
    )

    # Create the .arranged marker file
    with open(arranged_marker, 'w') as marker_file:
        marker_file.write('')

In [6]:
def randomize_subject_data(dataset_folder, train_percentage=70, val_percentage=20):
    """
    Randomizes and splits participant data into training, validation, and test sets.

    Participant folders are moved into ``train``, ``val`` and ``test`` subfolders of
    ``dataset_folder``. Whatever is not assigned to train or val goes to test.

    The split folders themselves and hidden directories are never treated as
    participants, so calling this function again on an already split dataset is a
    no-op instead of shuffling the split folders into each other. Participants are
    sorted before the seeded shuffle so the split does not depend on the order in
    which the file system lists directories.

    Parameters:
        dataset_folder (str): The path to the main dataset folder.
        train_percentage (int): The percentage of data allocated to the training set.
        val_percentage (int): The percentage of data allocated to the validation set.

    Returns:
        None
    """
    split_names = ('train', 'val', 'test')
    participant_folders = sorted(
        f for f in os.listdir(dataset_folder)
        if f not in split_names
        and not f.startswith('.')
        and os.path.isdir(os.path.join(dataset_folder, f))
    )
    # A private generator with a fixed seed: reproducible, and it leaves the
    # global `random` state untouched for the rest of the notebook.
    random.Random(42).shuffle(participant_folders)

    total_participants = len(participant_folders)
    train_count = int(total_participants * train_percentage / 100)
    val_count = int(total_participants * val_percentage / 100)

    train_folder, val_folder, test_folder = (
        os.path.join(dataset_folder, name) for name in split_names
    )
    for split_folder in (train_folder, val_folder, test_folder):
        os.makedirs(split_folder, exist_ok=True)

    for i, participant_folder in enumerate(participant_folders):
        participant_folder_path = os.path.join(dataset_folder, participant_folder)
        if i < train_count:
            destination = train_folder
        elif i < train_count + val_count:
            destination = val_folder
        else:
            destination = test_folder
        shutil.move(participant_folder_path, destination)

In [7]:
def _read_frame_count_cache(frame_count_file):
    """Return (frame_count, eeg_channels, meg_channels) from the cache file, or None.

    None is returned when the file is missing, malformed, or in the legacy format
    that stored only the frame count (no channel counts).
    """
    if not os.path.exists(frame_count_file):
        return None
    with open(frame_count_file, 'r') as f:
        fields = f.read().split()
    if len(fields) != 3:
        return None
    try:
        return tuple(int(field) for field in fields)
    except ValueError:
        return None


def count_frames(dataset_path, mode='train'):
    """
    Counts the number of frames in the dataset and returns the frame count.

    The result is cached in ``<dataset_path>/<mode>/.frameCount`` as three
    whitespace-separated integers (frames, EEG channels, MEG channels). A cache
    file that lacks the channel counts is ignored and rewritten, because the
    channel counts are part of the return value.

    Parameters:
        dataset_path (str): The path to the main dataset folder.
        mode (str): The dataset mode ('train', 'val', or 'test').

    Returns:
        tuple: (frame_count, eeg_channel_count, meg_channel_count). The channel
        counts are those of the last recording that was read (0 if there is none).

    Raises:
        Exception: Any error raised while opening a ``.fif`` file is logged and
            re-raised unchanged.
    """
    mode_path = os.path.join(dataset_path, mode)
    frame_count_file = os.path.join(mode_path, '.frameCount')

    cached = _read_frame_count_cache(frame_count_file)
    if cached is not None:
        logger.info(f"Loaded cached frame count: {cached[0]}")
        return cached
    if os.path.exists(frame_count_file):
        logger.info(f"Ignoring {frame_count_file}: it does not contain channel counts. Recounting.")

    frame_count = 0
    eeg_channels = 0
    meg_channels = 0

    participant_folders = [
        f for f in os.listdir(mode_path)
        if not f.startswith('.') and not f.endswith('.zip') and os.path.isdir(os.path.join(mode_path, f))
    ]

    logger.info(f"Counting frames in {mode} dataset...")
    for participant_folder in participant_folders:
        participant_path = os.path.join(mode_path, participant_folder)
        fif_files = [file for file in os.listdir(participant_path) if file.endswith('.fif')]
        for fif in fif_files:
            fif_path = os.path.join(participant_path, fif)
            try:
                with contextlib.redirect_stdout(None), contextlib.redirect_stderr(None):
                    raw_data = mne.io.read_raw_fif(fif_path, preload=False, verbose=False)
            except Exception as e:
                logger.error(f"Error reading {fif_path}, {e}")
                raise
            # EEG and MEG channels of one recording share the same time axis, so
            # the sample count is taken once from the recording itself.
            frame_count += raw_data.n_times
            # Read a single sample per channel type: the channel selection is the
            # same as for a full read, without loading the recording into memory.
            eeg_channels = raw_data.get_data(picks='eeg', start=0, stop=1).shape[0]
            meg_channels = raw_data.get_data(picks='meg', start=0, stop=1).shape[0]
    logger.info(f"Successfully counted {frame_count} frames in {mode} dataset...")

    with open(frame_count_file, 'w') as f:
        f.write(f"{frame_count} {eeg_channels} {meg_channels}")

    return frame_count, eeg_channels, meg_channels

In [8]:
def calculate_chunk_number(frame_count, eeg_channel_count, meg_channel_count,
                           memory_limit='32', vram_limit='16',
                           memory_chunk_size='4', vram_chunk_size='8',
                           eeg_values_type=np.float16, meg_values_type=np.float16):
    """
    Calculates and logs the number of chunks for memory and VRAM usage.

    The chunk number is the number of chunks of the given size needed to hold the
    whole dataset, i.e. the total size divided by the chunk size, rounded up. A
    dataset smaller than one chunk therefore needs 1 chunk, not 0.

    Parameters:
        frame_count (int): Total number of frames.
        eeg_channel_count (int): Number of EEG channels.
        meg_channel_count (int): Number of MEG channels.
        memory_limit (str): Memory limit in GB. Currently not used in the calculation.
        vram_limit (str): VRAM limit in GB. Currently not used in the calculation.
        memory_chunk_size (str): Memory chunk size in GB.
        vram_chunk_size (str): VRAM chunk size in GB.
        eeg_values_type (dtype): Data type for EEG values.
        meg_values_type (dtype): Data type for MEG values.

    Returns:
        None
    """
    logger.info("Calculating chunk sizes...")
    eeg_value_type_size = np.dtype(eeg_values_type).itemsize
    total_eeg_size = eeg_channel_count * frame_count * eeg_value_type_size

    meg_value_type_size = np.dtype(meg_values_type).itemsize
    total_meg_size = meg_channel_count * frame_count * meg_value_type_size

    total_size = total_eeg_size + total_meg_size

    logger.info(f"Size of EEG and MEG data in dataset: {total_size / (1024**3):.2f} GB")

    memory_chunk_size_bytes = int(memory_chunk_size) * (1024 ** 3)
    vram_chunk_size_bytes = int(vram_chunk_size) * (1024 ** 3)

    # Ceiling division with integers only: -(-a // b) == ceil(a / b). A partially
    # filled last chunk still counts as a chunk.
    memory_chunk_number = -(-total_size // memory_chunk_size_bytes)
    vram_chunk_number = -(-total_size // vram_chunk_size_bytes)

    logger.info(f"Memory Chunk Number: {memory_chunk_number}")
    logger.info(f"VRAM Chunk Number: {vram_chunk_number}")

In [9]:
# Folder (relative to the working directory) that holds the dataset; it is the
# path passed to every pipeline step executed in the cells below.
dataset_path = os.path.join('data', 'openfmri')
# Archive URLs handed to download_dataset. Every entry is currently commented
# out, so this list is empty until URLs are re-enabled.
default_download_urls = [  # Links to normalized data of all participants
    # "https://s3.amazonaws.com/openneuro/ds000117/ds000117_R1.0.0/compressed/ds000117_R1.0.0_derivatives_sub01-04.zip",
    # "https://s3.amazonaws.com/openneuro/ds000117/ds000117_R1.0.0/compressed/ds000117_R1.0.0_derivatives_sub05-08.zip",
    # "https://s3.amazonaws.com/openneuro/ds000117/ds000117_R1.0.0/compressed/ds000117_R1.0.0_derivatives_sub09-12.zip",
    # "https://s3.amazonaws.com/openneuro/ds000117/ds000117_R1.0.0/compressed/ds000117_R1.0.0_derivatives_sub13-16.zip"
]

def download_dataset(dataset_path, download_urls):
    """
    Downloads the dataset from the specified URLs to the given dataset path.

    Each URL is fetched with the external ``wget`` program into a file named after
    the last path segment of the URL. Files that already exist are skipped. When
    every expected file is present afterwards, a ``.downloaded`` marker is written
    and later calls return immediately. A failed download is logged, its partial
    file is removed, and no marker is written, so the next call retries it.

    Parameters:
        dataset_path (str): The directory path where the dataset will be downloaded.
            It is created if it does not exist.
        download_urls (list): A list of URLs from which to download the dataset.
            With an empty list nothing is downloaded and no marker is written.

    Returns:
        None
    """
    # Local import: subprocess is only needed by this function.
    import subprocess

    os.makedirs(dataset_path, exist_ok=True)

    downloaded_marker = os.path.join(dataset_path, '.downloaded')
    if os.path.exists(downloaded_marker):
        logger.info("Dataset already downloaded. Skipping download.")
        return

    if not download_urls:
        # Writing the marker here would make a later call with real URLs a no-op.
        logger.warning("No download URLs were provided; nothing to download.")
        return

    logger.info(f"Downloading {len(download_urls)} files...")
    expected_files = []
    for url in download_urls:
        file_name = os.path.join(dataset_path, url.split('/')[-1])
        expected_files.append(file_name)
        if os.path.exists(file_name):
            logger.info(f"{file_name} already exists, skipping download.")
            continue

        logger.info(f"Downloading {file_name}...")
        try:
            # Argument list instead of a shell string: no quoting or injection
            # issues, and check=True turns a non-zero wget exit into an exception.
            subprocess.run(["wget", "-O", file_name, url], check=True)
        except (OSError, subprocess.CalledProcessError) as e:
            logger.error(f"Error downloading file {file_name}, {e}")
            # `wget -O` creates the target even when the transfer fails; remove the
            # partial file so it is neither skipped next time nor counted as done.
            if os.path.exists(file_name):
                os.remove(file_name)

    # Check if all files are downloaded
    missing_files = [f for f in expected_files if not os.path.exists(f)]
    if missing_files:
        logger.error(f"Download incomplete, {len(missing_files)} file(s) missing: {missing_files}")
        return

    with open(downloaded_marker, 'w') as f:
        f.write('Download completed successfully.')
    logger.info(f"Successfully downloaded {len(expected_files)} files.")

# Function execution

In [None]:
# download_dataset(dataset_path, default_download_urls)

In [None]:
# unzip_and_rename_in_folder(dataset_path)

In [None]:
# 3. Arrange folders: move the participant folders out of the unzipped archives
# and flatten each participant's nested 'ses-meg/meg' files.
arrange_folders(dataset_path)

In [None]:
# 4./5. Randomize and split subject data. The split must run only once, so the
# former "step 4" cell (a mislabelled duplicate of this call) is merged into it.
randomize_subject_data(dataset_path, train_percentage=70, val_percentage=20)

In [None]:
# 6. Count frames in the training split. count_frames also returns the EEG and
# MEG channel counts, which the chunk calculation in the next cell needs.
frame_count, eeg_channel_count, meg_channel_count = count_frames(dataset_path, mode='train')
logger.info(f"Frame Count: {frame_count}")
logger.info(f"EEG Channel Count: {eeg_channel_count}")
logger.info(f"MEG Channel Count: {meg_channel_count}")

In [None]:
# Calculate Chunk Number for the training split counted above.
# All sizes are given in GB as strings. NOTE: memory_limit and vram_limit are
# accepted by calculate_chunk_number but are not read by its body.
calculate_chunk_number(
    frame_count,
    eeg_channel_count,
    meg_channel_count,
    memory_limit='32',
    vram_limit='16',
    memory_chunk_size='4',
    vram_chunk_size='8',
    eeg_values_type=np.float16,
    meg_values_type=np.float16
)