In [1]:
from pathlib import Path
import sys

# More reliable: get the project root from the notebook's location
ROOT_PATH = Path(__file__).parent.parent if '__file__' in globals() else Path.cwd().parent
# Or even better for notebooks:
ROOT_PATH = Path().resolve().parent  # Goes up from notebooks/ folder

sys.path.append(str(ROOT_PATH))

from src.data.dataset import OrionAEFrameDataset

In [2]:
train_set = OrionAEFrameDataset(
    data_path=r"C:\Users\nguye\Documents\GitHub\orion-ae-study\data\raw\segmented_ms_30_0_o_0_00_c_A_B_C_D_20251213_092549",
    config_path=r"C:\Users\nguye\Documents\GitHub\orion-ae-study\configs\dataset\example_1.yaml",
    type="train"
)

val_set = OrionAEFrameDataset(
    data_path=r"C:\Users\nguye\Documents\GitHub\orion-ae-study\data\raw\segmented_ms_30_0_o_0_00_c_A_B_C_D_20251213_092549",
    config_path=r"C:\Users\nguye\Documents\GitHub\orion-ae-study\configs\dataset\example_1.yaml",
    type="val"
)

test_set = OrionAEFrameDataset(
    data_path=r"C:\Users\nguye\Documents\GitHub\orion-ae-study\data\raw\segmented_ms_30_0_o_0_00_c_A_B_C_D_20251213_092549",
    config_path=r"C:\Users\nguye\Documents\GitHub\orion-ae-study\configs\dataset\example_1.yaml",
    type="test"
)


In [3]:
len(test_set)

2301

In [4]:
test_set[0]['raw']

array([[ -3.295999 ,  -4.3946652,  -2.1973326, ...,  -9.887997 ,
         -9.887997 ,  -7.6906643],
       [  4.333628 ,   0.       ,  -3.295999 , ...,   4.333628 ,
          7.56859  ,  10.86459  ],
       [  5.432295 ,   2.1362956,   3.234962 , ..., -18.555254 ,
        -20.69155  , -21.790216 ]], shape=(3, 150000), dtype=float32)

In [5]:
import numpy as np

np.mean(test_set[0]['raw'], axis = 1)
np.std(test_set[0]['raw'], axis = 1)

array([5.04889 , 5.631505, 7.921509], dtype=float32)

In [6]:
np.sum((np.array([1,2,3]), np.array([3,4,5])), axis=0)

array([4, 6, 8])

In [7]:
import numpy as np
from tqdm import tqdm
from typing import Optional, Tuple, List

def calculate_norm_params(
    datasets: List,
    target_label: int,
    target_serie: str,
    show_progress: bool = True
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], dict]:
    """
    Calculate mean and std normalization parameters for a specific label and series.
    
    Uses memory-efficient running sums to avoid storing all data in RAM.
    
    Args:
        datasets: List of dataset objects (e.g., [train_set, val_set, test_set])
        target_label: Target label to filter by
        target_serie: Target series name to filter by
        show_progress: Whether to show progress bars
    
    Returns:
        Tuple of (means, stds, info_dict) where:
            - means: Array of means per channel, or None if no items found
            - stds: Array of stds per channel, or None if no items found
            - info_dict: Dictionary with 'item_count' and 'total_count'
    """
    num_channels = None
    sum_values = None
    sum_squared = None
    total_count = 0
    item_count = 0
    
    # Iterate through all datasets and accumulate statistics
    for dataset in datasets:
        iterator = dataset
        if show_progress:
            dataset_name = getattr(dataset, 'type', 'dataset')
            iterator = tqdm(dataset, desc=f"Processing {dataset_name}", leave=False)
        
        for item in iterator:
            if item['label'] == target_label and item['serie'] == target_serie:
                raw_data = item['raw']  # Shape: (channels, time_steps)
                
                # Initialize accumulators on first item
                if num_channels is None:
                    num_channels = raw_data.shape[0]
                    sum_values = np.zeros(num_channels, dtype=np.float64)
                    sum_squared = np.zeros(num_channels, dtype=np.float64)
                
                # Accumulate statistics
                sum_values += np.sum(raw_data, axis=1, dtype=np.float64)
                sum_squared += np.sum(raw_data.astype(np.float64) ** 2, axis=1)
                total_count += raw_data.shape[1]
                item_count += 1
    
    # Compute final statistics with numerical stability
    if total_count > 0:
        means = sum_values / total_count
        # More numerically stable variance calculation
        variance = (sum_squared / total_count) - (means ** 2)
        variance = np.maximum(variance, 0.0)  # Handle floating point errors
        stds = np.sqrt(variance)
        
        info = {
            'item_count': item_count,
            'total_count': total_count,
            'num_channels': num_channels
        }
        
        return means, stds, info
    else:
        return None, None, {'item_count': 0, 'total_count': 0, 'num_channels': None}


def calculate_norm_params_batch(
    datasets: List,
    target_labels: Optional[List[int]] = None,
    target_series: Optional[List[str]] = None,
    show_progress: bool = True
) -> dict:
    """
    Calculate normalization parameters for multiple label/series combinations.
    
    Args:
        datasets: List of dataset objects
        target_labels: List of target labels (None = all labels)
        target_series: List of target series (None = all series)
        show_progress: Whether to show progress bars
    
    Returns:
        Dictionary with keys like (label, serie) -> {'mean': array, 'std': array, 'info': dict}
    """
    results = {}
    
    # Collect all unique labels and series if not specified
    if target_labels is None or target_series is None:
        seen_labels = set()
        seen_series = set()
        for dataset in datasets:
            for item in dataset:
                seen_labels.add(item['label'])
                seen_series.add(item['serie'])
        
        if target_labels is None:
            target_labels = sorted(seen_labels)
        if target_series is None:
            target_series = sorted(seen_series)
    
    # Calculate for each combination
    for label in target_labels:
        for serie in target_series:
            means, stds, info = calculate_norm_params(
                datasets=datasets,
                target_label=label,
                target_serie=serie,
                show_progress=show_progress
            )
            results[(label, serie)] = {
                'mean': means,
                'std': stds,
                'info': info
            }
    
    return results


In [9]:
calculate_norm_params_batch(
    datasets=[train_set, val_set, test_set],
    target_labels=[6],
    target_series=['B', 'C', 'D', 'E', 'F']
)

                                                                     

{(6, 'B'): {'mean': array([0.58009147, 2.13962315, 0.12031674]),
  'std': array([4.78606225, 6.49078199, 7.51684796]),
  'info': {'item_count': 305, 'total_count': 45750000, 'num_channels': 3}},
 (6, 'C'): {'mean': array([1.1293965 , 2.04018865, 0.27358819]),
  'std': array([4.62223859, 7.79696241, 9.22351764]),
  'info': {'item_count': 311, 'total_count': 46650000, 'num_channels': 3}},
 (6, 'D'): {'mean': array([0.22941354, 1.4716678 , 0.6410115 ]),
  'std': array([4.81425054, 7.36753723, 9.13285726]),
  'info': {'item_count': 306, 'total_count': 45900000, 'num_channels': 3}},
 (6, 'E'): {'mean': array([0.11833483, 1.12777588, 0.30956035]),
  'std': array([4.89143049, 6.43525229, 9.54307939]),
  'info': {'item_count': 310, 'total_count': 46500000, 'num_channels': 3}},
 (6, 'F'): {'mean': array([0.53612407, 2.11247939, 0.30080982]),
  'std': array([4.88715638, 6.36863377, 8.54823905]),
  'info': {'item_count': 329, 'total_count': 49350000, 'num_channels': 3}}}

{(6, 'B'): {'mean': array([0.58009147, 2.13962315, 0.12031674]),
  'std': array([4.78606225, 6.49078199, 7.51684796]),
  'info': {'item_count': 305, 'total_count': 45750000, 'num_channels': 3}},
 (6, 'C'): {'mean': array([1.1293965 , 2.04018865, 0.27358819]),
  'std': array([4.62223859, 7.79696241, 9.22351764]),
  'info': {'item_count': 311, 'total_count': 46650000, 'num_channels': 3}},
 (6, 'D'): {'mean': array([0.22941354, 1.4716678 , 0.6410115 ]),
  'std': array([4.81425054, 7.36753723, 9.13285726]),
  'info': {'item_count': 306, 'total_count': 45900000, 'num_channels': 3}},
 (6, 'E'): {'mean': array([0.11833483, 1.12777588, 0.30956035]),
  'std': array([4.89143049, 6.43525229, 9.54307939]),
  'info': {'item_count': 310, 'total_count': 46500000, 'num_channels': 3}},
 (6, 'F'): {'mean': array([0.53612407, 2.11247939, 0.30080982]),
  'std': array([4.88715638, 6.36863377, 8.54823905]),
  'info': {'item_count': 329, 'total_count': 49350000, 'num_channels': 3}}}

In [None]:
# # Memory-efficient computation of statistics using running sums
# # This avoids storing all data in RAM at once

# target_label = 6
# target_serie = 'B'
# num_channels = None  # Will be determined from first item
# sum_values = None    # Running sum per channel
# sum_squared = None   # Running sum of squares per channel
# total_count = 0      # Total number of time steps across all items
# item_count = 0       # Number of items processed

# # Iterate through all datasets and accumulate statistics
# for dataset in [train_set, val_set, test_set]:
#     for item in dataset:
#         if item['label'] == target_label and item['serie'] == target_serie:
#             # item['raw'] has shape (channels, time_steps)
#             raw_data = item['raw']  # Shape: (channels, time_steps)
            
#             # Initialize accumulators on first item
#             if num_channels is None:
#                 num_channels = raw_data.shape[0]
#                 sum_values = np.zeros(num_channels)
#                 sum_squared = np.zeros(num_channels)
            
#             # Flatten along time_steps axis and accumulate
#             # For each channel, sum all time_steps
#             sum_values += np.sum(raw_data, axis=1)  # Sum across time_steps for each channel
#             sum_squared += np.sum(raw_data ** 2, axis=1)  # Sum of squares
#             total_count += raw_data.shape[1]  # Number of time_steps
#             item_count += 1

# # Compute final statistics
# if total_count > 0:
#     means = sum_values / total_count  # Mean per channel
#     # Standard deviation: sqrt(E[X^2] - E[X]^2)
#     stds = np.sqrt((sum_squared / total_count) - (means ** 2))
    
#     print(f"Found {item_count} items with label {target_label}")
#     print(f"Total time steps: {total_count}")
#     print(f"Means per channel: {means}")
#     print(f"Stds per channel: {stds}")
# else:
#     print(f"No items found with label {target_label}")
#     means, stds = None, None

    


Found 305 items with label 6 - B (305 items)
Total time steps: 45750000
Means per channel: [0.58009146 2.13962311 0.12031676]
Stds per channel: [4.78606218 6.49078193 7.51684789]