# 15. Data Augmentation

Data augmentation strategies for improving model generalization on sensor time series.

## Contents
1. [Setup](#1-setup)
2. [Time-Domain Augmentations](#2-time-domain-augmentations)
3. [Magnitude Augmentations](#3-magnitude-augmentations)
4. [Mixup and CutMix](#4-mixup-and-cutmix)
5. [Sensor-Specific Augmentations](#5-sensor-specific-augmentations)
6. [Augmentation Pipeline](#6-augmentation-pipeline)
7. [Evaluation with Augmentation](#7-evaluation-with-augmentation)

---

## 1. Setup

In [None]:
import sys
from pathlib import Path

# Add src to path
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root / 'src'))

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple, Optional, Callable, Union
import json
from tqdm.notebook import tqdm
from scipy import interpolate, signal
from dataclasses import dataclass
import random

# Environment check: print interpreter and PyTorch versions with the notebook output
print(f"Python: {sys.version}")
print(f"PyTorch: {torch.__version__}")

# Reproducibility: the augmentations below draw from three different RNGs
# (torch, numpy and the stdlib `random` module), so each one is seeded.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Plotting style
# NOTE(review): the 'seaborn-v0_8-*' style names only exist in newer matplotlib
# releases — confirm the pinned matplotlib version provides this one.
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['figure.figsize'] = (14, 6)

In [None]:
# Load sample data for visualization
DATA_DIR = project_root / 'outputs' / 'processed_v2'
test_path = DATA_DIR / 'test.pt'

if test_path.exists():
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects; only point this at files produced by this project's pipeline.
    test_data = torch.load(test_path, weights_only=False)
    # Only the first 10 entries are kept; the cells below just need a few samples.
    # presumably 'continuous' is [N, T, C] and 'categorical' is [N, T, K] — TODO confirm
    sample_continuous = torch.tensor(test_data['continuous'][:10], dtype=torch.float32)
    sample_categorical = torch.tensor(test_data['categorical'][:10], dtype=torch.long)
    print(f"Sample data: {sample_continuous.shape}")
else:
    # Generate synthetic sample
    # Fallback so the notebook still runs without the processed dataset:
    # 10 samples, 64 time steps, 155 continuous channels / 4 categorical columns.
    print("Using synthetic sample data")
    sample_continuous = torch.randn(10, 64, 155)
    sample_categorical = torch.randint(0, 10, (10, 64, 4))

# Single sample for visualization
# Cloned so later cells cannot modify the batch through this reference.
vis_sample = sample_continuous[0].clone()

## 2. Time-Domain Augmentations

Augmentations that modify the temporal structure of the signal.

In [None]:
class TimeWarping:
    """Non-linear time warping augmentation.

    A smooth, monotone-sorted warping curve is built from randomly jittered
    knots, and every channel is re-sampled (linear interpolation) at the
    warped time positions.

    Args:
        sigma: Jitter strength. Each knot is displaced by Gaussian noise with
            standard deviation ``sigma * T / num_knots``.
        knot_density: Number of time steps per knot; the number of knots is
            ``max(2, T // knot_density)``.
    """

    def __init__(self, sigma=0.2, knot_density=4):
        self.sigma = sigma
        self.knot_density = knot_density

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Apply time warping.

        Args:
            x: Input tensor of shape [T, C] or [B, T, C]

        Returns:
            Tensor with the same shape, dtype and device as ``x``. Each sample
            in a batch gets its own random warping curve. Sequences with fewer
            than two time steps are returned as an unchanged copy.
        """
        squeeze = x.dim() == 2
        if squeeze:
            x = x.unsqueeze(0)

        B, T, C = x.shape

        # Nothing to warp: a warping curve needs at least two distinct knots.
        if T < 2:
            unchanged = x.clone()
            return unchanged.squeeze(0) if squeeze else unchanged

        # Create warping curve using random knots
        num_knots = max(2, T // self.knot_density)
        knot_positions = np.linspace(0, T - 1, num_knots)
        # A cubic spline needs at least 4 support points; short sequences
        # fall back to a piecewise-linear warping curve instead of crashing.
        warp_kind = 'cubic' if num_knots >= 4 else 'linear'
        time_axis = np.arange(T)

        # detach/cpu so tensors that require grad or live on an accelerator work too
        x_np = x.detach().cpu().numpy()

        warped_batch = []
        for b in range(B):
            # Random warping at knot positions
            jitter = np.random.randn(num_knots) * self.sigma * T / num_knots
            warp_values = np.clip(knot_positions + jitter, 0, T - 1)
            warp_values = np.sort(warp_values)  # Keep the knot values non-decreasing

            # Interpolate to get full warping function; the spline may overshoot
            # between knots, so the result is clipped back into [0, T - 1].
            warp_fn = interpolate.interp1d(
                knot_positions, warp_values, kind=warp_kind, fill_value='extrapolate'
            )
            warped_indices = np.clip(warp_fn(time_axis), 0, T - 1)

            # Re-sample all channels at once along the time axis
            sample_fn = interpolate.interp1d(
                time_axis, x_np[b], kind='linear', axis=0, fill_value='extrapolate'
            )
            warped = sample_fn(warped_indices)

            warped_batch.append(torch.tensor(warped, dtype=x.dtype, device=x.device))

        result = torch.stack(warped_batch)
        return result.squeeze(0) if squeeze else result

# Test time warping
time_warp = TimeWarping(sigma=0.1)
warped_sample = time_warp(vis_sample)

# Visualize
fig, axes = plt.subplots(1, 2, figsize=(14, 4))

# Left panel: original and warped signal of a single channel, overlaid
channel = 0
axes[0].plot(vis_sample[:, channel].numpy(), label='Original', alpha=0.8)
axes[0].plot(warped_sample[:, channel].numpy(), label='Time Warped', alpha=0.8)
axes[0].set_xlabel('Time Step')
axes[0].set_ylabel('Value')
axes[0].set_title('Time Warping Effect (Channel 0)')
axes[0].legend()

# Show warping on multiple channels
# The first 5 channels are drawn; each is shifted up by 2 units per channel
# so the traces do not overlap (requires at least 5 channels in the sample).
for c in range(5):
    offset = c * 2
    axes[1].plot(vis_sample[:, c].numpy() + offset, 'b-', alpha=0.5)
    axes[1].plot(warped_sample[:, c].numpy() + offset, 'r-', alpha=0.5)
axes[1].set_xlabel('Time Step')
axes[1].set_title('Time Warping: Original (blue) vs Warped (red)')

plt.tight_layout()
plt.show()

In [None]:
class WindowSlicing:
    """Cut a random contiguous time window out of each sample.

    The window spans ``int(T * crop_ratio)`` steps. With ``resample=True`` it
    is stretched back to ``T`` steps via ``scipy.signal.resample``; otherwise
    it is zero-padded at the end to restore the original length.
    """

    def __init__(self, crop_ratio=0.9, resample=True):
        self.crop_ratio = crop_ratio
        self.resample = resample

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Slice ``x`` of shape [T, C] or [B, T, C]; the output has the same shape."""
        single = x.dim() == 2
        batch = x.unsqueeze(0) if single else x

        _, T, _ = batch.shape
        window = int(T * self.crop_ratio)
        tail = T - window

        pieces = []
        for sample in batch:
            # Every sample gets its own random window position
            offset = random.randint(0, tail)
            segment = sample[offset:offset + window, :]

            if not self.resample:
                # Zero-pad at the end of the time axis
                pieces.append(F.pad(segment.T, (0, tail)).T)
                continue

            # Stretch the window back to the full length
            stretched = signal.resample(segment.numpy(), T, axis=0)
            pieces.append(torch.tensor(stretched, dtype=batch.dtype))

        stacked = torch.stack(pieces)
        return stacked.squeeze(0) if single else stacked


class RandomCrop:
    """Trim random amounts from both ends of the time axis and reflect-pad back.

    Up to ``int(T * crop_fraction)`` steps are removed independently at the
    start and at the end; the removed steps are then replaced by a reflection
    of the remaining signal so the length is unchanged.
    """

    def __init__(self, crop_fraction=0.1):
        self.crop_fraction = crop_fraction

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Crop and re-pad ``x`` of shape [T, C] or [B, T, C]."""
        T = x.shape[-2]
        budget = int(T * self.crop_fraction)

        head = random.randint(0, budget)
        tail = random.randint(0, budget)
        kept = x[..., head:T - tail, :]

        # F.pad in reflect mode pads the last axis of a [B, C, T] tensor, so
        # unbatched input gets a temporary batch axis and time is moved last.
        unbatched = kept.dim() == 2
        if unbatched:
            kept = kept.unsqueeze(0)

        channels_first = kept.permute(0, 2, 1)
        restored = F.pad(channels_first, (head, tail), mode='reflect').permute(0, 2, 1)
        return restored.squeeze(0) if unbatched else restored


class TimeShift:
    """Roll the signal along the time axis by a random number of steps.

    The shift is drawn uniformly from ``[-m, m]`` with
    ``m = int(T * max_shift_ratio)``; samples pushed past one end wrap around
    to the other.
    """

    def __init__(self, max_shift_ratio=0.1):
        self.max_shift_ratio = max_shift_ratio

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Shift ``x`` (time on the second-to-last axis) circularly."""
        limit = int(x.shape[-2] * self.max_shift_ratio)
        steps = random.randint(-limit, limit)
        return torch.roll(x, shifts=steps, dims=-2)


# Test and visualize
augmentations = {
    'Window Slicing': WindowSlicing(crop_ratio=0.8),
    'Random Crop': RandomCrop(crop_fraction=0.15),
    'Time Shift': TimeShift(max_shift_ratio=0.1),
}

fig, axes = plt.subplots(2, 2, figsize=(14, 8))

# Top-left panel: the untouched reference signal
channel = 0
axes[0, 0].plot(vis_sample[:, channel].numpy())
axes[0, 0].set_title('Original')
axes[0, 0].set_xlabel('Time')

# Remaining three panels: one augmentation each, applied to the same sample
for ax, (name, aug) in zip(axes.flat[1:], augmentations.items()):
    augmented = aug(vis_sample)
    ax.plot(augmented[:, channel].numpy())
    ax.set_title(name)
    ax.set_xlabel('Time')

plt.tight_layout()
# savefig does not create missing directories, so make sure reports/ exists
reports_dir = project_root / 'reports'
reports_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(reports_dir / 'time_augmentations.png', dpi=150, bbox_inches='tight')
plt.show()

## 3. Magnitude Augmentations

Augmentations that modify signal amplitude and values.

In [None]:
class GaussianNoise:
    """Perturb every element with independent zero-mean Gaussian noise.

    Args:
        std: Standard deviation of the added noise.
    """

    def __init__(self, std=0.1):
        self.std = std

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus noise of the same shape, dtype and device."""
        return x + self.std * torch.randn_like(x)


class Scaling:
    """Random scaling of magnitude.

    Every channel is multiplied by its own factor drawn from
    ``N(1, sigma^2)``. The same factors are used for all time steps and, for
    batched input, for all samples in the batch.

    Args:
        sigma: Standard deviation of the per-channel scale factor around 1.
    """

    def __init__(self, sigma=0.1):
        self.sigma = sigma

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` (channels on the last axis) channel-wise."""
        # Random scale factor per channel
        scale = 1 + torch.randn(x.shape[-1]) * self.sigma
        # Follow the input's device (and floating dtype) so non-CPU tensors do
        # not hit a device mismatch and low-precision inputs are not upcast.
        scale = scale.to(
            device=x.device,
            dtype=x.dtype if x.is_floating_point() else scale.dtype,
        )
        return x * scale


class MagnitudeWarping:
    """Smooth, random magnitude warping.

    A smooth curve around 1 is interpolated through random knot values and
    multiplied onto the signal along the time axis. All channels of a sample
    share one curve; every sample in a batch gets its own.

    Args:
        sigma: Standard deviation of the knot values around 1.
        knot_density: Number of time steps per knot; the number of knots is
            ``max(2, T // knot_density)``.
    """

    def __init__(self, sigma=0.2, knot_density=4):
        self.sigma = sigma
        self.knot_density = knot_density

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Warp the magnitude of ``x`` of shape [T, C] or [B, T, C].

        Sequences with fewer than two time steps are returned as an
        unchanged copy.
        """
        squeeze = x.dim() == 2
        if squeeze:
            x = x.unsqueeze(0)

        B, T, C = x.shape

        # A curve needs at least two distinct knot positions.
        if T < 2:
            unchanged = x.clone()
            return unchanged.squeeze(0) if squeeze else unchanged

        num_knots = max(2, T // self.knot_density)
        # A cubic spline needs at least 4 support points; short sequences
        # fall back to a piecewise-linear curve instead of crashing.
        curve_kind = 'cubic' if num_knots >= 4 else 'linear'
        knot_positions = np.linspace(0, T - 1, num_knots)
        time_axis = np.arange(T)

        result = []
        for b in range(B):
            warp_values = 1 + np.random.randn(num_knots) * self.sigma

            warp_fn = interpolate.interp1d(
                knot_positions, warp_values, kind=curve_kind, fill_value='extrapolate'
            )
            # Shape [T, 1] so the curve broadcasts over the channel axis
            warp_curve = torch.tensor(
                warp_fn(time_axis), dtype=x.dtype, device=x.device
            ).unsqueeze(-1)

            result.append(x[b] * warp_curve)

        output = torch.stack(result)
        return output.squeeze(0) if squeeze else output


class ChannelDropout:
    """Randomly drop sensor channels.

    Each channel is zeroed independently with probability ``p``. The kept
    channels are not rescaled. For batched input the same channels are
    dropped in every sample of the batch.

    Args:
        p: Probability of dropping a channel.
    """

    def __init__(self, p=0.1):
        self.p = p

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Zero out random channels of ``x`` (channels on the last axis)."""
        keep = torch.rand(x.shape[-1]) > self.p
        # Follow the input's device (and floating dtype) so non-CPU tensors do
        # not hit a device mismatch and low-precision inputs are not upcast.
        keep = keep.to(
            device=x.device,
            dtype=x.dtype if x.is_floating_point() else torch.float32,
        )
        return x * keep


class ChannelShuffle:
    """Permute the contents of a random subset of channels.

    ``int(C * shuffle_ratio)`` channels are picked at random and their data
    is redistributed among those same channel positions. If fewer than two
    channels would be picked, the input is returned as-is (not copied).
    """

    def __init__(self, shuffle_ratio=0.1):
        self.shuffle_ratio = shuffle_ratio

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Shuffle channels of ``x`` (channels on the last axis)."""
        n_channels = x.shape[-1]
        n_pick = int(n_channels * self.shuffle_ratio)

        # Shuffling needs at least two channels to have any effect
        if n_pick < 2:
            return x

        # Destination slots, and the order in which the picked channels fill them
        slots = torch.randperm(n_channels)[:n_pick]
        sources = slots[torch.randperm(n_pick)]

        shuffled = x.clone()
        shuffled[..., slots] = x[..., sources]
        return shuffled


# Visualize magnitude augmentations
mag_augmentations = {
    'Gaussian Noise': GaussianNoise(std=0.2),
    'Scaling': Scaling(sigma=0.2),
    'Magnitude Warp': MagnitudeWarping(sigma=0.2),
    'Channel Dropout': ChannelDropout(p=0.3),
}

fig, axes = plt.subplots(2, 3, figsize=(15, 8))

# First panel: the untouched reference signal
channel = 0
axes[0, 0].plot(vis_sample[:, channel].numpy())
axes[0, 0].set_title('Original')

# Panels 2-5: one augmentation each, applied to a fresh copy of the sample
for ax, (name, aug) in zip(axes.flat[1:5], mag_augmentations.items()):
    augmented = aug(vis_sample.clone())
    ax.plot(augmented[:, channel].numpy())
    ax.set_title(name)

# Remove empty subplot
axes[1, 2].axis('off')

plt.tight_layout()
# savefig does not create missing directories, so make sure reports/ exists
reports_dir = project_root / 'reports'
reports_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(reports_dir / 'magnitude_augmentations.png', dpi=150, bbox_inches='tight')
plt.show()

## 4. Mixup and CutMix

Advanced mixing strategies for regularization.

In [None]:
class Mixup:
    """Convex combination of two samples (mixup) for time series.

    Args:
        alpha: Both shape parameters of the Beta distribution the mixing
            weight is drawn from.
    """

    def __init__(self, alpha=0.4):
        self.alpha = alpha

    def __call__(self, x1: torch.Tensor, x2: torch.Tensor, 
                 y1: Optional[torch.Tensor] = None, y2: Optional[torch.Tensor] = None
                ) -> Tuple[torch.Tensor, Optional[torch.Tensor], float]:
        """Blend ``x1`` and ``x2`` with a random weight.

        Returns:
            ``(mixed_x, mixed_y, lam)`` where ``lam`` is the weight given to
            ``x1``. ``mixed_y`` is blended with the same weight, or ``None``
            unless both ``y1`` and ``y2`` are provided.
        """
        lam = np.random.beta(self.alpha, self.alpha)
        mixed_x = lam * x1 + (1 - lam) * x2

        have_targets = y1 is not None and y2 is not None
        mixed_y = lam * y1 + (1 - lam) * y2 if have_targets else None

        return mixed_x, mixed_y, lam


class CutMix:
    """Replace one contiguous time window of a sample with another sample's data.

    Args:
        alpha: Both shape parameters of the Beta distribution that decides
            how much of the first sample is kept.
    """

    def __init__(self, alpha=1.0):
        self.alpha = alpha

    def __call__(self, x1: torch.Tensor, x2: torch.Tensor,
                 y1: Optional[torch.Tensor] = None, y2: Optional[torch.Tensor] = None
                ) -> Tuple[torch.Tensor, Optional[torch.Tensor], float]:
        """Paste a random time window of ``x2`` into a copy of ``x1``.

        Returns:
            ``(mixed, mixed_y, lam)`` where ``lam`` is the fraction of time
            steps still coming from ``x1`` after rounding the window length.
            ``mixed_y`` is ``None`` unless both targets are given.
        """
        T = x1.shape[-2]
        lam = np.random.beta(self.alpha, self.alpha)

        cut_len = int(T * (1 - lam))
        cut_start = random.randint(0, T - cut_len)
        window = slice(cut_start, cut_start + cut_len)

        mixed = x1.clone()
        mixed[..., window, :] = x2[..., window, :]

        # The window length was truncated to an integer, so recompute the
        # fraction that is effectively kept from x1.
        actual_lam = 1 - cut_len / T

        if y1 is None or y2 is None:
            return mixed, None, actual_lam

        mixed_y = actual_lam * y1 + (1 - actual_lam) * y2
        return mixed, mixed_y, actual_lam


class TemporalCutMix:
    """CutMix with multiple random temporal segments.

    Up to ``num_segments`` short windows of the second sample are pasted into
    a copy of the first one; each candidate window is used with probability
    0.5 and may overlap earlier ones.

    Args:
        num_segments: Maximum number of windows to paste.
        alpha: Stored for interface parity with ``CutMix``; not used when
            mixing.
    """

    def __init__(self, num_segments=3, alpha=1.0):
        self.num_segments = num_segments
        self.alpha = alpha

    def __call__(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Return a copy of ``x1`` with random time windows taken from ``x2``."""
        T = x1.shape[-2]
        mixed = x1.clone()

        if T == 0 or self.num_segments <= 0:
            return mixed

        # Longest allowed window. For short sequences T // (2 * num_segments)
        # is 0, which would make randint(1, 0) raise, so allow at least 1 step.
        max_len = min(T, max(1, T // (self.num_segments * 2)))

        for _ in range(self.num_segments):
            if random.random() < 0.5:
                seg_len = random.randint(1, max_len)
                start = random.randint(0, T - seg_len)
                mixed[..., start:start + seg_len, :] = x2[..., start:start + seg_len, :]

        return mixed


# Visualize mixing strategies
sample1 = sample_continuous[0]
sample2 = sample_continuous[1]

mixup = Mixup(alpha=0.4)
cutmix = CutMix(alpha=1.0)
temporal_cutmix = TemporalCutMix(num_segments=3)

# No targets are passed, so the second return value is None and is discarded
mixed_mixup, _, lam1 = mixup(sample1, sample2)
mixed_cutmix, _, lam2 = cutmix(sample1, sample2)
mixed_temporal = temporal_cutmix(sample1, sample2)

fig, axes = plt.subplots(2, 3, figsize=(15, 8))

channel = 0
axes[0, 0].plot(sample1[:, channel].numpy(), label='Sample 1')
axes[0, 0].plot(sample2[:, channel].numpy(), label='Sample 2', alpha=0.7)
axes[0, 0].set_title('Original Samples')
axes[0, 0].legend()

axes[0, 1].plot(mixed_mixup[:, channel].numpy())
axes[0, 1].set_title(f'Mixup (λ={lam1:.2f})')

axes[0, 2].plot(mixed_cutmix[:, channel].numpy())
axes[0, 2].set_title(f'CutMix (λ={lam2:.2f})')

axes[1, 0].plot(mixed_temporal[:, channel].numpy())
axes[1, 0].set_title('Temporal CutMix')

# Show difference
# Non-zero only inside the window that CutMix replaced with sample 2
axes[1, 1].fill_between(range(len(sample1)), 
                        (mixed_cutmix[:, channel] - sample1[:, channel]).numpy(),
                        alpha=0.5, label='Difference')
axes[1, 1].set_title('CutMix Difference from Sample 1')
axes[1, 1].axhline(y=0, color='k', linestyle='--', alpha=0.3)

axes[1, 2].axis('off')

plt.tight_layout()
# savefig does not create missing directories, so make sure reports/ exists
reports_dir = project_root / 'reports'
reports_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(reports_dir / 'mixing_augmentations.png', dpi=150, bbox_inches='tight')
plt.show()

## 5. Sensor-Specific Augmentations

Augmentations designed for CNC sensor data characteristics.

In [None]:
class SensorGroupMask:
    """Zero out whole groups of channels to mimic failed sensors.

    Args:
        sensor_groups: Mapping from a group name to the channel slice that
            belongs to it.
        mask_prob: Probability, evaluated independently per group, that the
            group is zeroed.
    """

    def __init__(self, sensor_groups: Dict[str, slice], mask_prob=0.1):
        self.sensor_groups = sensor_groups
        self.mask_prob = mask_prob

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Return a copy of ``x`` with randomly chosen channel groups set to 0."""
        masked = x.clone()
        for channels in self.sensor_groups.values():
            if random.random() < self.mask_prob:
                masked[..., channels] = 0
        return masked


class AxisPermutation:
    """Randomly reorder channels within predefined axis groups (e.g. swap X, Y, Z).

    Args:
        axis_groups: Lists of channel indices; channels are only exchanged
            with other channels of the same list.
        permute_prob: Probability, evaluated independently per group, that
            the group is permuted.
    """

    def __init__(self, axis_groups: List[List[int]], permute_prob=0.5):
        self.axis_groups = axis_groups
        self.permute_prob = permute_prob

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Return a copy of ``x`` with some axis groups internally permuted."""
        result = x.clone()

        for group in self.axis_groups:
            if random.random() >= self.permute_prob:
                continue
            order = torch.randperm(len(group)).tolist()
            sources = [group[i] for i in order]
            # Read from the untouched input so groups cannot interfere
            result[..., group] = x[..., sources]

        return result


class MotionProfilePerturbation:
    """Add Gaussian noise to the velocity and acceleration channel ranges.

    Args:
        velocity_noise: Standard deviation of the noise added to the
            velocity channels.
        acceleration_jitter: Standard deviation of the noise added to the
            acceleration channels.
    """

    def __init__(self, velocity_noise=0.05, acceleration_jitter=0.1):
        self.velocity_noise = velocity_noise
        self.acceleration_jitter = acceleration_jitter

    def __call__(self, x: torch.Tensor, velocity_channels: slice = slice(30, 60),
                 accel_channels: slice = slice(60, 90)) -> torch.Tensor:
        """Return a copy of ``x`` with noise on the given channel slices.

        Velocity noise is drawn first, then acceleration noise; all other
        channels are left untouched.
        """
        perturbed = x.clone()

        noise_plan = (
            (velocity_channels, self.velocity_noise),
            (accel_channels, self.acceleration_jitter),
        )
        for channels, strength in noise_plan:
            noise = torch.randn_like(perturbed[..., channels]) * strength
            perturbed[..., channels] += noise

        return perturbed


class SensorDrift:
    """Simulate gradual sensor drift over time.

    A random subset of channels receives a linear ramp that starts at 0 on
    the first time step and reaches a random offset in
    ``[-drift_rate, drift_rate)`` on the last one. For batched input the same
    ramps are applied to every sample.

    Args:
        drift_rate: Maximum absolute offset reached at the end of the sequence.
        channel_prob: Probability, per channel, that drift is applied.
    """

    def __init__(self, drift_rate=0.01, channel_prob=0.2):
        self.drift_rate = drift_rate
        self.channel_prob = channel_prob

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` (shape [T, C] or [B, T, C]) with linear drift added."""
        T, C = x.shape[-2], x.shape[-1]

        # Select channels to apply drift
        drift_channels = torch.rand(C) < self.channel_prob

        # Create linear drift: [T, 1] ramp times a per-channel slope in [-1, 1)
        t = torch.linspace(0, 1, T).unsqueeze(-1)
        drift = t * (torch.rand(C) * 2 - 1) * self.drift_rate
        drift[:, ~drift_channels] = 0

        # Follow the input's device (and floating dtype) so non-CPU tensors do
        # not hit a device mismatch and low-precision inputs are not upcast.
        drift = drift.to(
            device=x.device,
            dtype=x.dtype if x.is_floating_point() else drift.dtype,
        )

        # [T, C] broadcasts over a leading batch axis; the sum is a new tensor,
        # so the input is never modified.
        return x + drift


# Define sensor groups (hypothetical)
# The slices partition channels 0-154, matching the 155-channel sample layout
sensor_groups = {
    'position': slice(0, 30),
    'velocity': slice(30, 60),
    'acceleration': slice(60, 90),
    'force': slice(90, 120),
    'temperature': slice(120, 140),
    'misc': slice(140, 155)
}

# Test sensor-specific augmentations
sensor_augs = {
    'Sensor Mask': SensorGroupMask(sensor_groups, mask_prob=0.3),
    'Motion Perturb': MotionProfilePerturbation(velocity_noise=0.1),
    'Sensor Drift': SensorDrift(drift_rate=0.05, channel_prob=0.3),
}

fig, axes = plt.subplots(2, 2, figsize=(14, 8))

channel = 35  # Velocity channel
axes[0, 0].plot(vis_sample[:, channel].numpy())
axes[0, 0].set_title(f'Original (Channel {channel})')

# Remaining three panels: one augmentation each, applied to a fresh copy
for ax, (name, aug) in zip(axes.flat[1:], sensor_augs.items()):
    augmented = aug(vis_sample.clone())
    ax.plot(augmented[:, channel].numpy())
    ax.set_title(name)

plt.tight_layout()
# savefig does not create missing directories, so make sure reports/ exists
reports_dir = project_root / 'reports'
reports_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(reports_dir / 'sensor_augmentations.png', dpi=150, bbox_inches='tight')
plt.show()

## 6. Augmentation Pipeline

Combine augmentations into a configurable pipeline.

In [None]:
@dataclass
class AugmentationConfig:
    """Configuration for augmentation pipeline.

    The ``*_prob`` fields are the probabilities with which
    ``AugmentationPipeline`` applies the corresponding augmentation; the
    remaining fields are strength parameters passed to the augmentation
    constructors.
    """
    # Per-sample probability and knot jitter (sigma) of TimeWarping
    time_warp_prob: float = 0.3
    time_warp_sigma: float = 0.1
    # Per-sample probability and noise standard deviation of GaussianNoise
    noise_prob: float = 0.5
    noise_std: float = 0.1
    # Per-sample probability and scale-factor sigma of Scaling
    scaling_prob: float = 0.3
    scaling_sigma: float = 0.1
    # Per-sample probability of MagnitudeWarping (its sigma is fixed in the pipeline)
    magnitude_warp_prob: float = 0.3
    # Probability of applying ChannelDropout, and the per-channel drop probability it uses
    channel_dropout_prob: float = 0.2
    channel_dropout_rate: float = 0.1
    # Per-sample probability of TimeShift (its shift ratio is fixed in the pipeline)
    time_shift_prob: float = 0.2
    # Batch-level mixing; a value of 0.0 means the mixer is not created at all
    mixup_prob: float = 0.0  # Applied at batch level
    cutmix_prob: float = 0.0


class AugmentationPipeline:
    """Composable augmentation pipeline.

    Holds a fixed, ordered list of ``(probability, augmentation)`` pairs that
    are applied per sample, plus optional batch-level Mixup / CutMix.

    Args:
        config: Probabilities and strengths for the individual augmentations.
            Mixup / CutMix are only instantiated when their probability is
            greater than 0.
    """

    def __init__(self, config: AugmentationConfig):
        self.config = config

        # Initialize augmentations; the list order is the application order
        self.augmentations = [
            (config.time_warp_prob, TimeWarping(sigma=config.time_warp_sigma)),
            (config.noise_prob, GaussianNoise(std=config.noise_std)),
            (config.scaling_prob, Scaling(sigma=config.scaling_sigma)),
            (config.magnitude_warp_prob, MagnitudeWarping(sigma=0.2)),
            (config.channel_dropout_prob, ChannelDropout(p=config.channel_dropout_rate)),
            (config.time_shift_prob, TimeShift(max_shift_ratio=0.1)),
        ]

        self.mixup = Mixup(alpha=0.4) if config.mixup_prob > 0 else None
        self.cutmix = CutMix(alpha=1.0) if config.cutmix_prob > 0 else None

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Apply random augmentations.

        Each augmentation is applied independently with its configured
        probability, in list order, starting from a copy of ``x``.
        """
        result = x.clone()

        for prob, aug in self.augmentations:
            if random.random() < prob:
                result = aug(result)

        return result

    def augment_batch(self, batch: torch.Tensor, 
                      targets: Optional[torch.Tensor] = None
                     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Augment a batch with mixing strategies.

        Every sample first goes through the per-sample augmentations. Then at
        most one mixing strategy is applied to the whole batch: Mixup is tried
        first, and CutMix only if Mixup was not applied. Samples are mixed
        with a random permutation of the same batch.

        Args:
            batch: Tensor whose first axis indexes the samples.
            targets: Optional targets, indexed the same way as ``batch``.

        Returns:
            ``(augmented_batch, targets)``. ``targets`` are blended with the
            same weight as the inputs when mixing was applied, and passed
            through unchanged otherwise.
        """
        B = batch.shape[0]

        # torch.stack rejects an empty list; there is nothing to augment or mix.
        if B == 0:
            return batch.clone(), targets

        # Apply per-sample augmentations
        augmented = torch.stack([self(sample) for sample in batch])

        # Apply mixup/cutmix
        if self.mixup is not None and random.random() < self.config.mixup_prob:
            mixer = self.mixup
        elif self.cutmix is not None and random.random() < self.config.cutmix_prob:
            mixer = self.cutmix
        else:
            mixer = None

        if mixer is not None:
            indices = torch.randperm(B)
            partner_targets = targets[indices] if targets is not None else None
            augmented, targets, _ = mixer(
                augmented, augmented[indices], targets, partner_targets
            )

        return augmented, targets


# Create pipeline with default config
config = AugmentationConfig()
pipeline = AugmentationPipeline(config)

# Apply to batch
# With the default config both mixup_prob and cutmix_prob are 0.0, so only the
# per-sample augmentations run; no targets are passed, so the second return
# value is None and is discarded.
augmented_batch, _ = pipeline.augment_batch(sample_continuous)

# The augmentations preserve shape, so both lines should print the same shape
print(f"Original batch shape: {sample_continuous.shape}")
print(f"Augmented batch shape: {augmented_batch.shape}")

In [None]:
# Visualize multiple augmented versions of same sample
fig, axes = plt.subplots(3, 3, figsize=(15, 10))

# First panel: the untouched reference signal
channel = 0
axes[0, 0].plot(vis_sample[:, channel].numpy())
axes[0, 0].set_title('Original')

# Remaining eight panels: independent random draws from the pipeline
for ax in axes.flat[1:]:
    augmented = pipeline(vis_sample)
    ax.plot(augmented[:, channel].numpy())
    ax.set_title('Augmented')

plt.suptitle('Multiple Random Augmentations of Same Sample', y=1.02)
plt.tight_layout()
# savefig does not create missing directories, so make sure reports/ exists
reports_dir = project_root / 'reports'
reports_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(reports_dir / 'augmentation_pipeline.png', dpi=150, bbox_inches='tight')
plt.show()

## 7. Evaluation with Augmentation

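Wrap the samples in a dataset that augments them on the fly, then measure how strongly each augmentation configuration perturbs the data. No model is trained in this section.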

In [None]:
class AugmentedDataset(torch.utils.data.Dataset):
    """Dataset that augments the continuous features on every access.

    Args:
        continuous: Indexable collection of continuous feature tensors.
        categorical: Indexable collection of categorical features, aligned
            with ``continuous``.
        targets: Optional targets, aligned with ``continuous``.
        augmentation_pipeline: Callable applied to the continuous features of
            a sample; skipped when falsy.
        training: Augmentation is only applied while this is true.
    """

    def __init__(self, continuous, categorical, targets=None, 
                 augmentation_pipeline=None, training=True):
        self.continuous = continuous
        self.categorical = categorical
        self.targets = targets
        self.pipeline = augmentation_pipeline
        self.training = training

    def __len__(self):
        """Number of samples, taken from the continuous features."""
        return len(self.continuous)

    def __getitem__(self, idx):
        """Return ``(x, cat)`` or, when targets exist, ``(x, cat, target)``."""
        features = self.continuous[idx]
        categories = self.categorical[idx]

        # Categorical features and targets are never augmented
        if self.training and self.pipeline:
            features = self.pipeline(features)

        item = (features, categories)
        if self.targets is not None:
            item += (self.targets[idx],)
        return item


# Create augmented dataset
# No targets are passed, so each item is an (x, cat) pair.
augmented_dataset = AugmentedDataset(
    sample_continuous,
    sample_categorical,
    augmentation_pipeline=pipeline,
    training=True
)

print(f"Dataset size: {len(augmented_dataset)}")
# Indexing runs the pipeline; [0][0] is the augmented continuous tensor of sample 0
print(f"Sample shape: {augmented_dataset[0][0].shape}")

In [None]:
# Augmentation strength analysis
def analyze_augmentation_strength(pipeline, sample, n_augmentations=100):
    """Measure how far repeated augmentation moves a sample from its original.

    Args:
        pipeline: Callable mapping a [T, C] sample to an augmented sample of
            the same shape.
        sample: The reference sample.
        n_augmentations: Number of independent augmented draws.

    Returns:
        Dict with ``mean_abs_diff`` and ``max_abs_diff`` (mean / maximum over
        all positions of the draw-averaged absolute deviation),
        ``augmentation_std`` (mean standard deviation across draws) and
        ``per_channel_diff`` (deviation averaged over time, one value per
        channel).
    """
    draws = [pipeline(sample) for _ in range(n_augmentations)]
    augmented_samples = torch.stack(draws)

    # Absolute deviation from the original, averaged over the draws -> [T, C]
    deviation = (augmented_samples - sample.unsqueeze(0)).abs().mean(dim=0)
    # Spread between the draws at every position -> [T, C]
    spread = augmented_samples.std(dim=0)

    return {
        'mean_abs_diff': deviation.mean().item(),
        'max_abs_diff': deviation.max().item(),
        'augmentation_std': spread.mean().item(),
        'per_channel_diff': deviation.mean(dim=0)
    }

# Analyze default pipeline
# Uses the function's default of 100 augmented draws of the visualization sample
stats = analyze_augmentation_strength(pipeline, vis_sample)

print("Augmentation Strength Analysis:")
print(f"  Mean absolute difference: {stats['mean_abs_diff']:.4f}")
print(f"  Max absolute difference: {stats['max_abs_diff']:.4f}")
print(f"  Augmentation std: {stats['augmentation_std']:.4f}")

In [None]:
# Compare different augmentation configurations
# Fields that are not listed for a preset keep their AugmentationConfig defaults
# (e.g. magnitude_warp_prob and time_shift_prob are the same in all three).
configs = {
    'Light': AugmentationConfig(
        noise_prob=0.3, noise_std=0.05,
        scaling_prob=0.2, scaling_sigma=0.05,
        time_warp_prob=0.1, time_warp_sigma=0.05
    ),
    'Medium': AugmentationConfig(
        noise_prob=0.5, noise_std=0.1,
        scaling_prob=0.3, scaling_sigma=0.1,
        time_warp_prob=0.3, time_warp_sigma=0.1
    ),
    'Heavy': AugmentationConfig(
        noise_prob=0.7, noise_std=0.2,
        scaling_prob=0.5, scaling_sigma=0.2,
        time_warp_prob=0.5, time_warp_sigma=0.2,
        channel_dropout_prob=0.3
    ),
}

# Measure each preset with 50 augmented draws of the same sample.
# Note: `stats` is rebound on every iteration and ends up holding the last preset's result.
strength_comparison = {}
for name, cfg in configs.items():
    pipe = AugmentationPipeline(cfg)
    stats = analyze_augmentation_strength(pipe, vis_sample, n_augmentations=50)
    strength_comparison[name] = stats
    print(f"{name}: mean_diff={stats['mean_abs_diff']:.4f}, std={stats['augmentation_std']:.4f}")

In [None]:
# Save augmentation configuration
# Only the overridden fields of each preset are stored; every other field is
# expected to fall back to the AugmentationConfig defaults when loaded.
aug_config_dict = {
    'light': {
        'noise_prob': 0.3, 'noise_std': 0.05,
        'scaling_prob': 0.2, 'scaling_sigma': 0.05,
        'time_warp_prob': 0.1, 'time_warp_sigma': 0.05
    },
    'medium': {
        'noise_prob': 0.5, 'noise_std': 0.1,
        'scaling_prob': 0.3, 'scaling_sigma': 0.1,
        'time_warp_prob': 0.3, 'time_warp_sigma': 0.1
    },
    'heavy': {
        'noise_prob': 0.7, 'noise_std': 0.2,
        'scaling_prob': 0.5, 'scaling_sigma': 0.2,
        'time_warp_prob': 0.5, 'time_warp_sigma': 0.2,
        'channel_dropout_prob': 0.3
    }
}

config_path = project_root / 'configs' / 'augmentation_configs.json'
# parents=True so the directory is created even if intermediate folders are missing
config_path.parent.mkdir(parents=True, exist_ok=True)
# Explicit encoding keeps the file identical across platforms
with open(config_path, 'w', encoding='utf-8') as f:
    json.dump(aug_config_dict, f, indent=2)

print(f"Augmentation configs saved to: {config_path}")

---

## Summary

This notebook provides comprehensive data augmentation for sensor time series:

1. **Time-Domain**: Time warping, window slicing, random crop, time shift
2. **Magnitude**: Gaussian noise, scaling, magnitude warping, channel dropout, channel shuffle
3. **Mixing**: Mixup, CutMix, Temporal CutMix for regularization
4. **Sensor-Specific**: Sensor group masking, axis permutation, motion profile perturbation, drift simulation
5. **Pipeline**: Configurable, composable augmentation pipeline
6. **Evaluation**: Strength analysis and configuration comparison

---

**Navigation:**
← [Previous: 14_robustness_testing](14_robustness_testing.ipynb) |
[Next: 16_architecture_comparison](16_architecture_comparison.ipynb) →