In [4]:
import mne
import numpy as np

# Load a single run file
fif_path = "/srv/synaptech_openfmri/test/sub-06/run_03.fif"

# Load the raw data
raw = mne.io.read_raw_fif(fif_path, preload=True)

# Get basic info
print("\nFile Info:")
print(f"Path: {fif_path}")
print(f"Number of channels: {len(raw.ch_names)}")
print(f"Sampling frequency: {raw.info['sfreq']} Hz")
print(f"Duration: {raw.times.max():.2f} seconds")

# Get the data array
data = raw.get_data()
print("\nData Shape:")
print(f"Shape: {data.shape}")  # Should be (n_channels, n_timepoints)
print(f"Data type: {data.dtype}")

# Basic statistics
print("\nData Statistics:")
print(f"Mean: {np.mean(data):.4f}")
print(f"Std: {np.std(data):.4f}")
print(f"Min: {np.min(data):.4f}")
print(f"Max: {np.max(data):.4f}")

# Show first few channel names
print("\nFirst 5 channels:")
print(raw.ch_names[:5])

# Also check if there's a corresponding .txt file
import os
txt_path = fif_path.replace('.fif', '.txt')
if os.path.exists(txt_path):
    print("\nCorresponding .txt file exists:")
    with open(txt_path, 'r') as f:
        first_line = f.readline().strip()
    print(f"First line: {first_line}")


File Info:
Path: /srv/synaptech_openfmri/test/sub-06/run_03.fif
Number of channels: 404
Sampling frequency: 220.0 Hz
Duration: 497.50 seconds

Data Shape:
Shape: (404, 109450)
Data type: float64

Data Statistics:
Mean: 19.0903
Std: 387.2113
Min: -9.9427
Max: 7936.0000

First 5 channels:
['MEG0113', 'MEG0112', 'MEG0111', 'MEG0122', 'MEG0123']

Corresponding .txt file exists:
First line: ************************************************************


- mean-pooling

In [3]:


import os
import logging
import mne
import numpy as np
from tqdm import tqdm  # for progress bars
import warnings

warnings.filterwarnings(
    "ignore",
    message=".* does not conform to MNE naming conventions.*",
    category=RuntimeWarning,
)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
if not logger.handlers:
    handler = logging.StreamHandler()
    formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)

mne.set_log_level("WARNING")

_already_printed_shapes = False

def mean_pool_data(data, pool_size=5):
    """
    data: (n_channels, n_times)
    Returns data_pooled: (n_channels, n_times // pool_size)
         by averaging every 'pool_size' samples along time.
    """
    n_channels, n_times = data.shape
    remainder = n_times % pool_size
    if remainder != 0:
        # Truncate leftover frames
        data = data[:, : (n_times - remainder)]

    # Reshape => (n_channels, new_len, pool_size), then mean along last axis
    data = data.reshape(n_channels, -1, pool_size)  # now time dimension is n_times // pool_size
    data_pooled = data.mean(axis=-1)
    return data_pooled

def chunk_into_windows(data, window_size=275):
    """
    data: (n_channels, n_times_pooled)
    Returns windows: (n_windows, n_channels, window_size)
      by carving out consecutive blocks of size 'window_size'.
    """
    n_channels, n_times = data.shape
    n_windows = n_times // window_size
    usable = n_windows * window_size
    data = data[:, :usable]

    # Reshape => (n_channels, n_windows, window_size)
    data = data.reshape(n_channels, n_windows, window_size)
    # Transpose => (n_windows, n_channels, window_size)
    data = np.transpose(data, (1, 0, 2))
    return data

def pool_chunk_and_overwrite(fif_path, pool_size=5, window_size=275):
    """
    1) Load the original .fif as 'raw'.
    2) Mean-pool in blocks of 5 => reduces sampling rate by factor of 5.
    3) Chunk the pooled data into as many 275-frame windows as possible.
    4) Concatenate those windows along time => single time series.
    5) Make a *new* Raw object with that data & updated sfreq.
    6) Save (overwrite) at the same path => your original file is replaced.
    """
    # Load data
    logger.info(f"Processing & Overwriting: {fif_path}")
    raw_original = mne.io.read_raw_fif(fif_path, preload=True)
    original_data = raw_original.get_data()  # shape: (n_channels, n_times)
    original_sf = raw_original.info['sfreq']

    # Mean pool => shape: (n_channels, n_times // 5)
    data_pooled = mean_pool_data(original_data, pool_size=pool_size)
    
    # Break into 275-frame windows => shape: (n_windows, n_channels, 275)
    windows = chunk_into_windows(data_pooled, window_size=window_size)
    n_windows = windows.shape[0]

    # If there's at least one window, flatten back into a single time series
    # shape => (n_channels, n_windows * 275)
    if n_windows > 0:
        flattened_data = windows.transpose(1, 0, 2).reshape(original_data.shape[0], -1)
    else:
        # No windows fit => just store an empty or minimal dataset
        flattened_data = np.zeros((original_data.shape[0], 0), dtype=np.float32)

    # Print shape details for every file
    logger.info(
        f"  Shapes:\n"
        f"    Original: {original_data.shape},\n"
        f"    After pooling: {data_pooled.shape},\n"
        f"    #windows of {window_size} frames: {n_windows},\n"
        f"    Final 'flattened' shape: {flattened_data.shape}."
    )

    # Build a fresh Info with the same channels but new sfreq = original / pool_size
    ch_names = raw_original.ch_names
    ch_types = raw_original.get_channel_types()
    new_info = mne.create_info(ch_names=ch_names, sfreq=(original_sf / pool_size), ch_types=ch_types)
    
    # Create the new Raw object
    new_raw = mne.io.RawArray(flattened_data, new_info)
    # Keep measurement date
    new_raw.set_meas_date(raw_original.info['meas_date'])
    # Copy bad channels if desired
    new_raw.info['bads'] = list(raw_original.info['bads'])

    # Overwrite the original file with the new data
    new_raw.save(fif_path, overwrite=True)

def process_all(dataset_path="/srv/synaptech_openmri_1"):
    """
    1) Walk train/val/test subfolders.
    2) For each subject & each .fif:
       - Overwrite the original .fif with the new mean-pooled, multi-window data.
    """
    logger.info(f"Starting to process dataset at: {dataset_path}")
    
    for mode in ["train", "val", "test"]:
        mode_path = os.path.join(dataset_path, mode)
        if not os.path.isdir(mode_path):
            logger.warning(f"Skipping non-existent folder: {mode_path}")
            continue

        subjects = sorted(os.listdir(mode_path))
        logger.info(f"[{mode}] Found {len(subjects)} potential items in: {mode_path}")

        for subject in subjects:
            subject_path = os.path.join(mode_path, subject)
            if not os.path.isdir(subject_path):
                logger.debug(f"Skipping file (not folder): {subject}")
                continue

            fif_files = [f for f in os.listdir(subject_path) if f.endswith('.fif')]
            if not fif_files:
                logger.info(f"No .fif files found for subject {subject}, skipping.")
                continue

            logger.info(f"Subject {subject} has {len(fif_files)} .fif runs.")
            for run_file in fif_files:
                fif_path = os.path.join(subject_path, run_file)
                pool_chunk_and_overwrite(fif_path, pool_size=5, window_size=275)

    logger.info("Processing completed!")

# ---------------------------------------------------------------------------
# Execute
# ---------------------------------------------------------------------------
logger.info("Script started")
process_all("/srv/synaptech_openfmri")
logger.info("Script finished")


2024-12-24 00:29:23,293 [INFO] Script started
2024-12-24 00:29:23,294 [INFO] Starting to process dataset at: /srv/synaptech_openfmri
2024-12-24 00:29:23,294 [INFO] [train] Found 11 potential items in: /srv/synaptech_openfmri/train
2024-12-24 00:29:23,294 [INFO] Subject sub-01 has 6 .fif runs.
2024-12-24 00:29:23,294 [INFO] Processing & Overwriting: /srv/synaptech_openfmri/train/sub-01/run_06.fif
2024-12-24 00:29:25,419 [INFO]   Example shapes:
    Original: (404, 550000),
    After pooling: (404, 110000),
    #windows of 275 frames: 400,
    Final 'flattened' shape: (404, 110000).
2024-12-24 00:29:25,840 [INFO] Processing & Overwriting: /srv/synaptech_openfmri/train/sub-01/run_05.fif
2024-12-24 00:29:27,809 [INFO] Processing & Overwriting: /srv/synaptech_openfmri/train/sub-01/run_01.fif
2024-12-24 00:29:29,665 [INFO] Processing & Overwriting: /srv/synaptech_openfmri/train/sub-01/run_04.fif
2024-12-24 00:29:31,512 [INFO] Processing & Overwriting: /srv/synaptech_openfmri/train/sub-01/run

- making shards

In [43]:
import os
import logging
import warnings
import mne
import numpy as np
import torch  # for saving shards as .pt
from tqdm import tqdm

warnings.filterwarnings(
    "ignore",
    message=".* does not conform to MNE naming conventions.*",
    category=RuntimeWarning,
)

# ---------------------------------------------------------------------------
# Configure Python logger
# ---------------------------------------------------------------------------
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
if not logger.handlers:
    handler = logging.StreamHandler()
    formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)

# ---------------------------------------------------------------------------
# MNE logging
# ---------------------------------------------------------------------------
mne.set_log_level("WARNING")


def collect_global_stats(dataset_path="/srv/synaptech_openfmri"):
    """
    Pass 1:
      - Traverse all train/val/test .fif files,
      - Collect sums and sums_of_squares for each channel (EEG & MAG).
      - Then compute mean & std for each channel.

    Returns a dictionary with:
      {
        "eeg_mean": np.array of shape [n_eeg_channels],
        "eeg_std":  np.array of shape [n_eeg_channels],
        "mag_mean": np.array of shape [n_mag_channels],
        "mag_std":  np.array of shape [n_mag_channels]
      }

    We'll assume channel ordering is consistent across files and subjects.
    If not, you'd need a more robust approach (matching channel names, etc.).
    """
    logger.info("Collecting global channel stats (EEG & MAG) across the entire dataset...")

    # We'll store partial sums for EEG and MAG separately
    # We'll figure out the shape once we read the first file that has EEG, or MAG.
    eeg_channel_count = None
    mag_channel_count = None

    # We'll accumulate:
    #   sums_eeg, sums_sqr_eeg
    #   sums_mag, sums_sqr_mag
    # plus a global "num_samples_eeg" (the total # timepoints across all runs for EEG)
    sums_eeg = None
    sums_sqr_eeg = None
    num_samples_eeg = 0

    sums_mag = None
    sums_sqr_mag = None
    num_samples_mag = 0

    # Potential subfolders
    for mode in ["train", "val", "test"]:
        mode_path = os.path.join(dataset_path, mode)
        if not os.path.isdir(mode_path):
            continue

        subjects = sorted(os.listdir(mode_path))
        for subject in subjects:
            subject_path = os.path.join(mode_path, subject)
            if not os.path.isdir(subject_path):
                continue

            run_files = [f for f in os.listdir(subject_path) if f.endswith(".fif")]
            for run_file in run_files:
                run_path = os.path.join(subject_path, run_file)
                logger.debug(f"(Stats) Loading {run_path}")

                raw = mne.io.read_raw_fif(run_path, preload=True)
                data_all = raw.get_data()  # shape: [n_channels, n_times]

                # EEG
                eeg_indices = mne.pick_types(raw.info, meg=False, eeg=True)
                if len(eeg_indices) > 0:
                    # shape => [n_eeg, n_times]
                    eeg_data = data_all[eeg_indices, :]
                    n_eeg_ch, n_times = eeg_data.shape

                    if eeg_channel_count is None:
                        eeg_channel_count = n_eeg_ch
                        # Initialize accumulators
                        sums_eeg = np.zeros(eeg_channel_count, dtype=np.float64)
                        sums_sqr_eeg = np.zeros(eeg_channel_count, dtype=np.float64)

                    # Check that the number of channels matches the first assumption
                    if n_eeg_ch != eeg_channel_count:
                        logger.error(f"EEG channel count mismatch: {n_eeg_ch} vs. {eeg_channel_count}")
                        continue

                    # Accumulate
                    sums_eeg += np.sum(eeg_data, axis=1)
                    sums_sqr_eeg += np.sum(eeg_data**2, axis=1)
                    num_samples_eeg += n_times

                # MAG
                mag_indices = mne.pick_types(raw.info, meg='mag', eeg=False)
                if len(mag_indices) > 0:
                    mag_data = data_all[mag_indices, :]
                    n_mag_ch, n_times = mag_data.shape

                    if mag_channel_count is None:
                        mag_channel_count = n_mag_ch
                        sums_mag = np.zeros(mag_channel_count, dtype=np.float64)
                        sums_sqr_mag = np.zeros(mag_channel_count, dtype=np.float64)

                    if n_mag_ch != mag_channel_count:
                        logger.error(f"MAG channel count mismatch: {n_mag_ch} vs. {mag_channel_count}")
                        continue

                    sums_mag += np.sum(mag_data, axis=1)
                    sums_sqr_mag += np.sum(mag_data**2, axis=1)
                    num_samples_mag += n_times

    # Now compute means & std for EEG
    # shape => [n_eeg_channels], [n_eeg_channels]
    if eeg_channel_count is not None and num_samples_eeg > 0:
        mean_eeg = sums_eeg / num_samples_eeg
        # std = sqrt( E[X^2] - (E[X])^2 )
        # E[X^2] = sums_sqr / total_count
        e_x2 = sums_sqr_eeg / num_samples_eeg
        var_eeg = e_x2 - mean_eeg**2
        # guard against negative var from numerical issues
        var_eeg = np.maximum(var_eeg, 1e-20)
        std_eeg = np.sqrt(var_eeg)
    else:
        mean_eeg = np.array([])
        std_eeg = np.array([])

    # Now compute means & std for MAG
    if mag_channel_count is not None and num_samples_mag > 0:
        mean_mag = sums_mag / num_samples_mag
        e_x2 = sums_sqr_mag / num_samples_mag
        var_mag = e_x2 - mean_mag**2
        var_mag = np.maximum(var_mag, 1e-20)
        std_mag = np.sqrt(var_mag)
    else:
        mean_mag = np.array([])
        std_mag = np.array([])

    logger.info("Done gathering global stats.")
    return {
        "eeg_mean": mean_eeg,
        "eeg_std": std_eeg,
        "mag_mean": mean_mag,
        "mag_std": std_mag,
    }


def make_3d_windows(data, window_size=275, allow_padding=False, mode="EEG"):
    """
    Same as before: chunk data -> [n_channels, window_size, n_windows].
    """
    n_channels, n_times = data.shape
    n_windows = n_times // window_size
    leftover = n_times % window_size

    if n_windows == 0:
        if allow_padding and (n_times > 0):
            pad_window = np.zeros((n_channels, window_size), dtype=data.dtype)
            pad_window[:, :n_times] = data
            pad_window = pad_window.reshape(n_channels, window_size, 1)
            return torch.from_numpy(pad_window)
        else:
            logger.warning(f"Skipping {mode.upper()} because it has only {n_times} < {window_size} samples.")
            return None

    used = n_windows * window_size
    blocks = data[:, :used]
    blocks = blocks.reshape(n_channels, n_windows, window_size)
    blocks = np.transpose(blocks, (0, 2, 1))

    if leftover > 0 and allow_padding:
        pad_window = np.zeros((n_channels, window_size), dtype=data.dtype)
        pad_window[:, :leftover] = data[:, used:]
        pad_window = pad_window.reshape(n_channels, window_size, 1)
        blocks = np.concatenate([blocks, pad_window], axis=2)

    return torch.from_numpy(blocks)


def apply_zscore_and_save_shards(dataset_path, stats, window_size=275, shard_output_dir=None, allow_padding=False):
    """
    Pass 2:
      - Use the global stats (stats['eeg_mean'], stats['eeg_std'], etc.).
      - For each run, pick EEG => (data - mean) / std => chunk => save float16.
      - Same for MAG.
    """
    logger.info(f"Applying z-score normalization & saving shards in float16. Using window_size={window_size}")

    # Unpack stats
    mean_eeg = stats["eeg_mean"]
    std_eeg = stats["eeg_std"]
    mean_mag = stats["mag_mean"]
    std_mag = stats["mag_std"]

    for mode in ["train", "val", "test"]:
        mode_path = os.path.join(dataset_path, mode)
        if not os.path.isdir(mode_path):
            continue

        subjects = sorted(os.listdir(mode_path))
        for subject in subjects:
            subject_path = os.path.join(mode_path, subject)
            if not os.path.isdir(subject_path):
                continue

            run_files = [f for f in os.listdir(subject_path) if f.endswith(".fif")]
            for run_file in run_files:
                run_path = os.path.join(subject_path, run_file)
                logger.info(f"[Z-score pass] Loading {run_path}")

                raw = mne.io.read_raw_fif(run_path, preload=True)
                data_all = raw.get_data()

                # EEG
                eeg_indices = mne.pick_types(raw.info, meg=False, eeg=True)
                eeg_data = data_all[eeg_indices, :]
                if eeg_data.size > 0 and len(eeg_indices) == len(mean_eeg):
                    # Z-score each channel
                    # shape => [n_eeg, n_times]
                    for ch_i in range(len(eeg_indices)):
                        eeg_data[ch_i, :] = (eeg_data[ch_i, :] - mean_eeg[ch_i]) / std_eeg[ch_i]

                    shard_eeg = make_3d_windows(eeg_data, window_size=window_size, allow_padding=allow_padding, mode="EEG")
                    if shard_eeg is not None:
                        if shard_output_dir:
                            out_dir = os.path.join(shard_output_dir, mode, subject)
                        else:
                            out_dir = os.path.join(subject_path, "EEG_shard_zscored")
                        os.makedirs(out_dir, exist_ok=True)

                        name_base, _ = os.path.splitext(run_file)
                        out_fname = f"{name_base}_eeg.pt"
                        out_path = os.path.join(out_dir, out_fname)

                        # Convert to half
                        shard_eeg_fp16 = shard_eeg.half()
                        torch.save(shard_eeg_fp16, out_path)
                        logger.info(f"Z-scored EEG => shape={tuple(shard_eeg_fp16.shape)}: {out_path}")
                else:
                    logger.warning(f"EEG channel mismatch or no EEG data for {run_file}.")

                # MAG
                mag_indices = mne.pick_types(raw.info, meg='mag', eeg=False)
                mag_data = data_all[mag_indices, :]
                if mag_data.size > 0 and len(mag_indices) == len(mean_mag):
                    for ch_i in range(len(mag_indices)):
                        mag_data[ch_i, :] = (mag_data[ch_i, :] - mean_mag[ch_i]) / std_mag[ch_i]

                    shard_mag = make_3d_windows(mag_data, window_size=window_size, allow_padding=allow_padding, mode="MAG")
                    if shard_mag is not None:
                        if shard_output_dir:
                            out_dir = os.path.join(shard_output_dir, mode, subject)
                        else:
                            out_dir = os.path.join(subject_path, "MAG_shard_zscored")
                        os.makedirs(out_dir, exist_ok=True)

                        name_base, _ = os.path.splitext(run_file)
                        out_fname = f"{name_base}_mag.pt"
                        out_path = os.path.join(out_dir, out_fname)

                        shard_mag_fp16 = shard_mag.half()
                        torch.save(shard_mag_fp16, out_path)
                        logger.info(f"Z-scored MAG => shape={tuple(shard_mag_fp16.shape)}: {out_path}")
                else:
                    logger.warning(f"MAG channel mismatch or no MAG data for {run_file}.")


def main_zscore_shard_pipeline(dataset_path="/srv/synaptech_openfmri", window_size=275, shard_output_dir=None):
    """
    1) Gather global stats (means & std) for each EEG & MAG channel across entire dataset.
    2) Re-loop, apply z-score, chunk, save shards in float16.
    """
    logger.info("=== PASS 1: Gathering global stats ===")
    stats = collect_global_stats(dataset_path=dataset_path)
    logger.info("Computed global EEG means/stds shapes: "
                f"{stats['eeg_mean'].shape}, {stats['eeg_std'].shape}")
    logger.info("Computed global MAG means/stds shapes: "
                f"{stats['mag_mean'].shape}, {stats['mag_std'].shape}")

    logger.info("=== PASS 2: Applying z-score & saving shards ===")
    apply_zscore_and_save_shards(dataset_path=dataset_path,
                                 stats=stats,
                                 window_size=window_size,
                                 shard_output_dir=shard_output_dir,
                                 allow_padding=False)
    logger.info("Z-score pipeline complete.")


if __name__ == "__main__":
    main_zscore_shard_pipeline(
        dataset_path="/srv/synaptech_openfmri",
        window_size=275,
        shard_output_dir=None  # or e.g. "shards_zscore"
    )


2024-12-24 19:58:04,354 [INFO] === PASS 1: Gathering global stats ===
2024-12-24 19:58:04,355 [INFO] Collecting global channel stats (EEG & MAG) across the entire dataset...
2024-12-24 19:58:29,217 [INFO] Done gathering global stats.
2024-12-24 19:58:29,245 [INFO] Computed global EEG means/stds shapes: (74,), (74,)
2024-12-24 19:58:29,245 [INFO] Computed global MAG means/stds shapes: (102,), (102,)
2024-12-24 19:58:29,246 [INFO] === PASS 2: Applying z-score & saving shards ===
2024-12-24 19:58:29,246 [INFO] Applying z-score normalization & saving shards in float16. Using window_size=275
2024-12-24 19:58:29,246 [INFO] [Z-score pass] Loading /srv/synaptech_openfmri/train/sub-01/run_06.fif
2024-12-24 19:58:29,480 [INFO] Z-scored EEG => shape=(74, 275, 400): /srv/synaptech_openfmri/train/sub-01/EEG_shard_zscored/run_06_eeg.pt
2024-12-24 19:58:29,521 [INFO] Z-scored MAG => shape=(102, 275, 400): /srv/synaptech_openfmri/train/sub-01/MAG_shard_zscored/run_06_mag.pt
2024-12-24 19:58:29,521 [IN

### Cleanup
- removing shard folders

In [42]:
import os
import shutil
from pathlib import Path
import logging

# Set up logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
if not logger.handlers:
    handler = logging.StreamHandler()
    formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)

def cleanup_shard_folders(dataset_path="/srv/synaptech_openfmri"):
    """
    Remove all 'EEG_shard' and 'MAG_shard' folders from the dataset.
    """
    logger.info(f"Starting cleanup of shard folders in: {dataset_path}")
    
    folders_removed = 0
    
    # Walk through train/val/test folders
    for mode in ["train", "val", "test"]:
        mode_path = Path(dataset_path) / mode
        if not mode_path.is_dir():
            logger.warning(f"Skipping non-existent folder: {mode_path}")
            continue
            
        # For each subject folder
        for subject_path in mode_path.iterdir():
            if not subject_path.is_dir():
                continue
                
            # Check for and remove EEG_shard folder
            eeg_shard_path = subject_path / "EEG_shard"
            if eeg_shard_path.exists():
                logger.info(f"Removing: {eeg_shard_path}")
                shutil.rmtree(eeg_shard_path)
                folders_removed += 1
                
            # Check for and remove MAG_shard folder
            mag_shard_path = subject_path / "MAG_shard"
            if mag_shard_path.exists():
                logger.info(f"Removing: {mag_shard_path}")
                shutil.rmtree(mag_shard_path)
                folders_removed += 1
    
    logger.info(f"Cleanup completed! Removed {folders_removed} shard folders.")

if __name__ == "__main__":
    cleanup_shard_folders()

2024-12-24 19:57:59,640 [INFO] Starting cleanup of shard folders in: /srv/synaptech_openfmri
2024-12-24 19:57:59,641 [INFO] Cleanup completed! Removed 0 shard folders.


# Observability & Debugging

In [52]:
import torch
import os
from pathlib import Path
import logging

# ---------------------------------------------------------------------------
# Configure Python logger
# ---------------------------------------------------------------------------
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)  # Adjust to DEBUG for more verbosity
if not logger.handlers:
    handler = logging.StreamHandler()
    formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)

# ---------------------------------------------------------------------------
# Inspection Function
# ---------------------------------------------------------------------------
def inspect_all_shards(dataset_path="/srv/synaptech_openfmri", per_channel_stats=True):
    """
    Inspect all EEG and MAG shards across all modes and subjects.
    Prints global and per-channel statistics to verify z-score normalization.
    """
    logger.info(f"Starting inspection of all shards in dataset: {dataset_path}")
    
    modes = ["train", "val", "test"]
    shard_types = ["EEG_shard_zscored", "MAG_shard_zscored"]
    
    for mode in modes:
        mode_path = Path(dataset_path) / mode
        if not mode_path.is_dir():
            logger.warning(f"Skipping non-existent mode folder: {mode_path}")
            continue
        
        subjects = sorted(mode_path.iterdir())
        logger.info(f"Inspecting mode: {mode} with {len(subjects)} subjects.")
        
        for subject_path in subjects:
            if not subject_path.is_dir():
                logger.debug(f"Skipping non-directory: {subject_path}")
                continue
            
            for shard_type in shard_types:
                shard_dir = subject_path / shard_type
                if not shard_dir.is_dir():
                    logger.warning(f"Missing shard directory: {shard_dir}, skipping.")
                    continue
                
                shard_files = sorted(shard_dir.glob("*.pt"))
                logger.info(f"Inspecting {len(shard_files)} shards in {shard_dir}")
                
                for shard_file in shard_files:
                    try:
                        # Suppress the FutureWarning by specifying weights_only=True if appropriate
                        # Since these are custom shards, ensure you are loading them securely
                        tensor = torch.load(shard_file, weights_only=True)
                    except TypeError:
                        # If weights_only is not a valid argument (older PyTorch versions), load normally
                        tensor = torch.load(shard_file)
                    
                    print(f"\n{mode} | {subject_path.name} | {shard_type} | {shard_file.name}:")
                    print(f"  Shape: {tensor.shape}")
                    print(f"  Dtype: {tensor.dtype}")
                    print(f"  Device: {tensor.device}")
    
                    with torch.no_grad():
                        # Convert to float32 for accurate statistics without altering the saved data
                        data_f32 = tensor.float()
                        
                        global_min = data_f32.min().item()
                        global_max = data_f32.max().item()
                        global_mean = data_f32.mean().item()
                        global_std = data_f32.std().item()
                    
                    print(f"  (Global) Min:  {global_min:.6e}")
                    print(f"  (Global) Max:  {global_max:.6e}")
                    print(f"  (Global) Mean: {global_mean:.6e}")
                    print(f"  (Global) Std:  {global_std:.6e}")
    
                    # Optional per-channel stats
                    if per_channel_stats:
                        n_channels = data_f32.shape[0]
                        # Reshape to (n_channels, -1) without altering memory layout
                        data_2d = data_f32.reshape(n_channels, -1)
                        
                        print("  Per-channel stats:")
                        for ch_idx in range(n_channels):
                            ch_data = data_2d[ch_idx]
                            ch_min = ch_data.min().item()
                            ch_max = ch_data.max().item()
                            ch_mean = ch_data.mean().item()
                            ch_std = ch_data.std().item()
                            print(f"    Channel {ch_idx:02d}: "
                                  f"Min={ch_min:.4e}, Max={ch_max:.4e}, "
                                  f"Mean={ch_mean:.4e}, Std={ch_std:.4e}")
    
    logger.info("Completed inspection of all shards.")

# ---------------------------------------------------------------------------
# Main Execution
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    inspect_all_shards(
        dataset_path="/srv/synaptech_openfmri",
        per_channel_stats=True  # Set to False to skip per-channel statistics
    )


2024-12-24 23:27:03,055 [INFO] Starting inspection of all shards in dataset: /srv/synaptech_openfmri
2024-12-24 23:27:03,055 [INFO] Collecting data from mode: train with 11 subjects.



train | sub-01 | EEG_shard_zscored | run_01_eeg.pt:
  Shape: torch.Size([74, 275, 392])
  Dtype: torch.float16
  Device: cpu

train | sub-01 | EEG_shard_zscored | run_02_eeg.pt:
  Shape: torch.Size([74, 275, 397])
  Dtype: torch.float16
  Device: cpu

train | sub-01 | EEG_shard_zscored | run_03_eeg.pt:
  Shape: torch.Size([74, 275, 404])
  Dtype: torch.float16
  Device: cpu

train | sub-01 | EEG_shard_zscored | run_04_eeg.pt:
  Shape: torch.Size([74, 275, 394])
  Dtype: torch.float16
  Device: cpu

train | sub-01 | EEG_shard_zscored | run_05_eeg.pt:
  Shape: torch.Size([74, 275, 404])
  Dtype: torch.float16
  Device: cpu

train | sub-01 | EEG_shard_zscored | run_06_eeg.pt:
  Shape: torch.Size([74, 275, 400])
  Dtype: torch.float16
  Device: cpu

train | sub-01 | MAG_shard_zscored | run_01_mag.pt:
  Shape: torch.Size([102, 275, 392])
  Dtype: torch.float16
  Device: cpu

train | sub-01 | MAG_shard_zscored | run_02_mag.pt:
  Shape: torch.Size([102, 275, 397])
  Dtype: torch.float16
  De

2024-12-24 23:27:03,680 [INFO] Collecting data from mode: val with 3 subjects.
2024-12-24 23:27:03,831 [INFO] Collecting data from mode: test with 2 subjects.



train | sub-16 | MAG_shard_zscored | run_04_mag.pt:
  Shape: torch.Size([102, 275, 397])
  Dtype: torch.float16
  Device: cpu

train | sub-16 | MAG_shard_zscored | run_05_mag.pt:
  Shape: torch.Size([102, 275, 394])
  Dtype: torch.float16
  Device: cpu

train | sub-16 | MAG_shard_zscored | run_06_mag.pt:
  Shape: torch.Size([102, 275, 393])
  Dtype: torch.float16
  Device: cpu

val | sub-03 | EEG_shard_zscored | run_01_eeg.pt:
  Shape: torch.Size([74, 275, 404])
  Dtype: torch.float16
  Device: cpu

val | sub-03 | EEG_shard_zscored | run_03_eeg.pt:
  Shape: torch.Size([74, 275, 403])
  Dtype: torch.float16
  Device: cpu

val | sub-03 | EEG_shard_zscored | run_04_eeg.pt:
  Shape: torch.Size([74, 275, 397])
  Dtype: torch.float16
  Device: cpu

val | sub-03 | EEG_shard_zscored | run_05_eeg.pt:
  Shape: torch.Size([74, 275, 395])
  Dtype: torch.float16
  Device: cpu

val | sub-03 | EEG_shard_zscored | run_06_eeg.pt:
  Shape: torch.Size([74, 275, 400])
  Dtype: torch.float16
  Device: cpu

2024-12-24 23:27:05,588 [INFO] Completed inspection of all shards.


    Channel 60: Min=-1.1973e+00, Max=4.8218e-01, Mean=1.9102e-09, Std=8.1178e-02
    Channel 61: Min=-1.1807e+00, Max=2.6294e-01, Mean=-8.5961e-10, Std=5.8069e-02
    Channel 62: Min=-1.3955e+00, Max=3.6475e-01, Mean=-1.7192e-09, Std=6.5591e-02
    Channel 63: Min=-8.8770e-01, Max=2.4072e-01, Mean=-7.6887e-09, Std=6.0809e-02
    Channel 64: Min=-1.3125e+00, Max=2.0813e-01, Mean=-7.5454e-09, Std=8.0615e-02
    Channel 65: Min=-1.4883e+00, Max=3.5962e-01, Mean=1.9102e-09, Std=6.9912e-02
    Channel 66: Min=-7.2803e-01, Max=3.6548e-01, Mean=-6.8769e-09, Std=8.9832e-02
    Channel 67: Min=-5.6006e-01, Max=4.9634e-01, Mean=7.6410e-10, Std=7.3214e-02
    Channel 68: Min=-5.7422e-01, Max=7.9297e-01, Mean=8.0230e-09, Std=6.8546e-02
    Channel 69: Min=-6.9385e-01, Max=7.5635e-01, Mean=3.4384e-09, Std=1.0579e-01
    Channel 70: Min=-1.2256e+00, Max=6.7676e-01, Mean=5.9217e-09, Std=8.8449e-02
    Channel 71: Min=-1.3242e+00, Max=1.0479e+00, Mean=3.1519e-09, Std=5.7632e-02
    Channel 72: Min=-1.

In [18]:
eeg_data = raw_data.get_data(picks='eeg')
meg_data = raw_data.get_data(picks='meg')
print("Shape of eeg_data:", eeg_data.shape)
print("Shape of meg_data:", meg_data.shape)

print("Type of eeg_data:", type(eeg_data))
print("Type of meg_data:", type(meg_data))

print(eeg_data.nbytes / (1000**3), "GB")
print(meg_data.nbytes / (1000**3), "GB")
total_data_size_gb = (eeg_data.nbytes + meg_data.nbytes) / (1000**3)
print("Total data size for participant 1:", total_data_size_gb, "GB")

total_participant_count = 19
estimated_total_data_size_gb = total_data_size_gb * total_participant_count
print("Estimated total openFMRI data size for all participants:", estimated_total_data_size_gb, "GB")



Shape of eeg_data: (74, 540100)
Shape of meg_data: (306, 540100)
Type of eeg_data: <class 'numpy.ndarray'>
Type of meg_data: <class 'numpy.ndarray'>
0.3197392 GB
1.3221648 GB
Total data size for participant 1: 1.641904 GB
Estimated total openFMRI data size for all participants: 31.196176 GB
