# Data Analysis Notebook

This notebook is exploratory only.

For benchmark runs and reproducible comparisons, use `experiments/manifest_repro_v1.json` with `scripts/run_experiment_manifest.py`.

## Configuration

**Important:** All experiments use **`n_mels=64`** for consistency and comparability across all model variants (TinyCNN, CRNN, CBAM, KD).

This notebook analyzes the cached spectrograms. If the cache was regenerated with different parameters, the outputs below may differ from the documented values. Always regenerate the cache with `python3 scripts/preprocess.py` to ensure consistency.

In [None]:
import sys
sys.path.insert(0, '..')

import numpy as np
import matplotlib.pyplot as plt
import h5py
from collections import Counter

In [None]:
# Load cached spectrograms
CACHE_PATH = '../results/cache/spectrograms/train.h5'
LABELS = ['quiet', 'breathe', 'snore']

# The handle is kept open on purpose: `specs` is an on-disk dataset that the
# cells below index one spectrogram at a time. The final cell closes it.
h5 = h5py.File(CACHE_PATH, 'r')
specs = h5['spectrograms']
# Later cells walk the labels element by element. Reading the label vector
# into memory once makes each of those lookups a plain array access instead
# of a separate HDF5 read. It still supports len(), labels[i] and labels[:].
# NOTE(review): assumes one small integer label per sample - confirm.
labels = h5['labels'][:]

print(f"Dataset shape: {specs.shape}")
print(f"n_mels: {h5.attrs['n_mels']}")
print(f"max_time: {h5.attrs['max_time']}")
print(f"sample_rate: {h5.attrs['sample_rate']}")

## Class Distribution

In [None]:
# Count samples per class. Keys are converted to plain Python ints so the
# lookups by class id below do not depend on the stored label dtype.
label_counts = Counter(int(value) for value in labels[:])
n_total = len(labels)

fig, ax = plt.subplots(figsize=(8, 5))
classes = list(LABELS)
counts = [label_counts[i] for i in range(len(LABELS))]

bars = ax.bar(classes, counts, color=['#2ecc71', '#3498db', '#9b59b6'])
ax.set_ylabel('Number of Samples')
ax.set_title('Class Distribution (Train Set)')
# Leave headroom above the tallest bar for its count label.
ax.margins(y=0.1)

for bar, count in zip(bars, counts):
    # The label is offset in points, not data units, so it sits just above
    # the bar whether the counts are in the tens or the tens of thousands.
    ax.annotate(f'{count:,}',
                xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()),
                xytext=(0, 3), textcoords='offset points',
                ha='center', va='bottom', fontsize=10)

plt.tight_layout()
plt.show()

print("\nClass counts:")
for i, name in enumerate(LABELS):
    # Guard against an empty dataset instead of dividing by zero.
    pct = label_counts[i] / n_total * 100 if n_total else 0.0
    print(f"  {name}: {label_counts[i]:,} ({pct:.1f}%)")

## Sample Spectrograms per Class

In [None]:
def get_samples_by_class(n_per_class=3, label_array=None, n_classes=3):
    """Return the first ``n_per_class`` sample indices of every class.

    Args:
        n_per_class: Maximum number of indices to collect per class.
        label_array: Optional 1-D sequence of integer class ids. When omitted,
            the notebook-level ``labels`` object is read once in full.
        n_classes: Number of classes; ids ``0 .. n_classes - 1`` are reported.

    Returns:
        Dict mapping each class id to a list of ascending sample indices
        (plain ints). A class with fewer than ``n_per_class`` samples gets a
        shorter, possibly empty, list. Labels outside ``0 .. n_classes - 1``
        are ignored rather than raising ``KeyError``.
    """
    if label_array is None:
        label_array = labels[:]
    # One bulk read replaces a Python loop of per-element dataset lookups.
    all_labels = np.asarray(label_array).astype(int)
    limit = max(int(n_per_class), 0)
    return {
        class_id: np.flatnonzero(all_labels == class_id)[:limit].tolist()
        for class_id in range(n_classes)
    }

# Pick three example indices per class and list them by class name.
samples = get_samples_by_class(n_per_class=3)
print("Sample indices per class:")
for class_id, class_name in enumerate(LABELS):
    print(f"  {class_name}: {samples[class_id]}")

In [None]:
# Plot spectrograms for each class: one row per class, one column per sample
fig, axes = plt.subplots(3, 3, figsize=(14, 8))

for row, (class_id, indices) in enumerate(samples.items()):
    for col, idx in enumerate(indices):
        spec = specs[idx]
        ax = axes[row, col]

        ax.imshow(spec, aspect='auto', origin='lower', cmap='viridis')

        if col == 0:
            ax.set_ylabel(LABELS[class_id], fontsize=12, fontweight='bold')
        if row == 0:
            ax.set_title(f'Sample {col+1}', fontsize=10)
        if row == 2:
            ax.set_xlabel('Time frames')

        # Tick the first, middle and last row of this spectrogram. Fixed
        # ticks at 0/64/127 only fit a 128-row image; on a shorter one they
        # fall outside the data and stretch the axis past it.
        n_bins = spec.shape[0]
        tick_positions = [0, n_bins // 2, n_bins - 1]
        ax.set_yticks(tick_positions)
        ax.set_yticklabels([str(pos) for pos in tick_positions])

plt.suptitle('Spectrogram Examples by Class', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## Spectrogram Statistics

In [None]:
# Summarise a seeded random subset (at most 1000 spectrograms): the mean,
# std, min and max of each sampled spectrogram, grouped by class label.
np.random.seed(42)
n_draw = min(1000, len(specs))
sample_indices = np.random.choice(len(specs), n_draw, replace=False)

stat_names = ('means', 'stds', 'mins', 'maxs')
stats_by_class = {cid: {stat: [] for stat in stat_names} for cid in range(3)}

for idx in sample_indices:
    spec = specs[idx]
    per_class = stats_by_class[int(labels[idx])]
    values = (spec.mean(), spec.std(), spec.min(), spec.max())
    for stat, value in zip(stat_names, values):
        per_class[stat].append(value)

print("Spectrogram statistics by class (sampled):")
print("-" * 60)
for class_id, name in enumerate(LABELS):
    class_stats = stats_by_class[class_id]
    # Each line reports the average of a per-spectrogram statistic; the
    # "+/-" part is the spread of that statistic across the sampled items.
    avg_mean, spread_mean = np.mean(class_stats['means']), np.std(class_stats['means'])
    avg_std, spread_std = np.mean(class_stats['stds']), np.std(class_stats['stds'])
    avg_min, avg_max = np.mean(class_stats['mins']), np.mean(class_stats['maxs'])
    print(f"{name}:")
    print(f"  Mean: {avg_mean:.2f} +/- {spread_mean:.2f}")
    print(f"  Std:  {avg_std:.2f} +/- {spread_std:.2f}")
    print(f"  Min:  {avg_min:.2f}, Max: {avg_max:.2f}")
    print()

In [None]:
# Box plot of mean values by class
fig, ax = plt.subplots(figsize=(8, 5))

data = [stats_by_class[i]['means'] for i in range(3)]
# Tick labels are set on the axes rather than through boxplot's `labels=`
# keyword: Matplotlib 3.9 deprecated that keyword in favour of `tick_labels=`,
# and setting the labels afterwards works on both older and newer releases.
bp = ax.boxplot(data, patch_artist=True)
ax.set_xticklabels(LABELS)

# One fill colour per class; `colors` is kept at notebook level.
colors = ['#2ecc71', '#3498db', '#9b59b6']
for patch, color in zip(bp['boxes'], colors):
    patch.set_facecolor(color)
    patch.set_alpha(0.7)

ax.set_ylabel('Mean Spectrogram Value (dB)')
ax.set_title('Distribution of Mean Spectrogram Values by Class')
plt.tight_layout()
plt.show()

## Average Spectrogram per Class

In [None]:
# Compute average spectrogram per class (using first 500 samples each)
avg_specs = {}
n_samples_avg = 500

# Read the label vector once and select indices with array operations,
# instead of looking labels up one element at a time for every class.
all_labels = np.asarray(labels[:]).astype(int)

for class_id in range(3):
    # flatnonzero returns ascending indices, so the slice keeps the first
    # `n_samples_avg` samples of the class in dataset order.
    class_indices = np.flatnonzero(all_labels == class_id)[:n_samples_avg].tolist()

    # Classes without any sample are skipped and get no entry in avg_specs.
    if class_indices:
        class_specs = np.array([specs[i] for i in class_indices])
        avg_specs[class_id] = class_specs.mean(axis=0)
        print(f"{LABELS[class_id]}: averaged {len(class_indices)} samples")

In [None]:
# Show the per-class average spectrograms side by side, each with its own
# colour bar.
fig, axes = plt.subplots(1, 3, figsize=(14, 4))

for panel, (class_id, avg_spec) in enumerate(avg_specs.items()):
    ax = axes[panel]
    # Colour limits come from each class's own value range, so the panels
    # are not on a shared colour scale.
    image = ax.imshow(
        avg_spec,
        aspect='auto', origin='lower', cmap='viridis',
        vmin=avg_spec.min(), vmax=avg_spec.max(),
    )
    ax.set_title(LABELS[class_id], fontsize=12, fontweight='bold')
    ax.set_xlabel('Time frames')
    if panel == 0:
        ax.set_ylabel('Mel bins')
    fig.colorbar(image, ax=ax, label='dB')

fig.suptitle('Average Spectrogram per Class', fontsize=14, fontweight='bold')
fig.tight_layout()
plt.show()

---

## Advanced Acoustic Analysis

The following sections provide deeper insights into acoustic characteristics that distinguish sleep events.

### 1. Temporal Energy Analysis

Analyze how acoustic energy evolves over time for each class. This reveals temporal patterns:
- **Quiet**: Minimal, stable energy
- **Breathe**: Periodic energy fluctuations
- **Snore**: High-energy bursts with varying patterns

In [None]:
# Compute temporal energy profiles (average dB over time)
print("Computing temporal energy profiles...")

temporal_profiles = {i: [] for i in range(3)}
n_samples_temporal = 200  # samples per class

# Read the label vector once; scanning the labels element by element for
# each class would look up every label three times.
all_labels = np.asarray(labels[:]).astype(int)

for class_id in range(3):
    # First `n_samples_temporal` indices of this class, in dataset order.
    class_indices = np.flatnonzero(all_labels == class_id)[:n_samples_temporal].tolist()
    for idx in class_indices:
        spec = specs[idx]
        # Average energy across mel bins for each time frame
        temporal_energy = spec.mean(axis=0)  # [time_frames]
        temporal_profiles[class_id].append(temporal_energy)

# Average temporal profiles across samples
avg_temporal = {i: np.mean(temporal_profiles[i], axis=0) for i in range(3)}
std_temporal = {i: np.std(temporal_profiles[i], axis=0) for i in range(3)}

print(f"Computed temporal profiles for {n_samples_temporal} samples per class")

In [None]:
# Draw the mean temporal profile of every class with a ±1 std band, then
# print a single variability figure per class.
fig, axes = plt.subplots(1, 3, figsize=(16, 4))
colors = ['#2ecc71', '#3498db', '#9b59b6']
n_frames = avg_temporal[0].shape[0]
# Frame index -> seconds; the 5.0 assumes every clip lasts five seconds.
time_axis = np.arange(n_frames) * (5.0 / n_frames)

for class_id, (ax, color) in enumerate(zip(axes, colors)):
    center = avg_temporal[class_id]
    spread = std_temporal[class_id]

    ax.plot(time_axis, center, color=color, linewidth=2, label='Mean')
    ax.fill_between(time_axis, center - spread, center + spread,
                    color=color, alpha=0.3, label='±1 std')

    ax.set_xlabel('Time (seconds)', fontsize=10)
    ax.set_ylabel('Mean Energy (dB)', fontsize=10)
    ax.set_title(f'{LABELS[class_id].capitalize()}', fontsize=12, fontweight='bold')
    ax.legend(loc='upper right', fontsize=8)
    ax.grid(alpha=0.3)

fig.suptitle('Temporal Energy Profiles', fontsize=14, fontweight='bold')
fig.tight_layout()
plt.show()

# Coefficient of variation of the class-average profile over time; the small
# constant keeps the division defined when the mean is zero.
print("\nTemporal Variability (CV = std/mean):")
for class_id, name in enumerate(LABELS):
    profile = avg_temporal[class_id]
    cv = np.std(profile) / (np.abs(np.mean(profile)) + 1e-6)
    print(f"  {name}: {cv:.3f}")

### 2. Frequency Band Analysis

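Compare energy distribution across frequency bands. The bands are the lowest quarter, middle half and highest quarter of the mel axis. The bin ranges below are for a 128-bin mel spectrogram; for a different `n_mels` (the Configuration section specifies 64), scale them proportionally.
- **Low** (mel bins 0-31): ~0-2kHz - Contains most snoring energy
- **Mid** (mel bins 32-95): ~2-6kHz - Breathing and vocal tract resonances
- **High** (mel bins 96-127): ~6-8kHz - High-frequency components and noise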

In [None]:
# Define frequency bands (mel bins) relative to the actual number of mel bins.
# Fixed edges of 32/96/128 only fit 128-bin spectrograms: on a 64-bin cache
# the high band would be empty (NaN mean) and the mid band cut short.
# NOTE(review): assumes specs is (n_samples, n_mels, time_frames), matching
# the axis comments used elsewhere in this notebook - confirm.
n_mel_bins = int(specs.shape[1])
low_edge = n_mel_bins // 4
high_edge = (3 * n_mel_bins) // 4
low_band = slice(0, low_edge)              # Low frequencies: lowest quarter
mid_band = slice(low_edge, high_edge)      # Mid frequencies: middle half
high_band = slice(high_edge, n_mel_bins)   # High frequencies: highest quarter

# Compute energy per band for each class
band_energy = {i: {'low': [], 'mid': [], 'high': []} for i in range(3)}
n_samples_bands = 500

# Read the label vector once instead of one element at a time per class.
all_labels = np.asarray(labels[:]).astype(int)

for class_id in range(3):
    # First `n_samples_bands` indices of this class, in dataset order.
    class_indices = np.flatnonzero(all_labels == class_id)[:n_samples_bands].tolist()
    for idx in class_indices:
        spec = specs[idx]
        # Compute mean energy in each band
        band_energy[class_id]['low'].append(spec[low_band, :].mean())
        band_energy[class_id]['mid'].append(spec[mid_band, :].mean())
        band_energy[class_id]['high'].append(spec[high_band, :].mean())

print(f"Computed frequency band energy for {n_samples_bands} samples per class")

In [None]:
# Plot frequency band energy distribution
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Defined here so the cell does not rely on a palette left behind by an
# earlier plotting cell.
colors = ['#2ecc71', '#3498db', '#9b59b6']

# Left: Bar plot of mean energy per band
ax = axes[0]
x = np.arange(3)  # 3 frequency bands
width = 0.25
bands_list = ['low', 'mid', 'high']

for i, class_id in enumerate(range(3)):
    means = [np.mean(band_energy[class_id][band]) for band in bands_list]
    ax.bar(x + i*width, means, width, label=LABELS[class_id], color=colors[i], alpha=0.8)

ax.set_xlabel('Frequency Band', fontsize=11)
ax.set_ylabel('Mean Energy (dB)', fontsize=11)
ax.set_title('Energy Distribution Across Frequency Bands', fontsize=12, fontweight='bold')
ax.set_xticks(x + width)
ax.set_xticklabels(['Low\n(0-2kHz)', 'Mid\n(2-6kHz)', 'High\n(6-8kHz)'])
ax.legend()
ax.grid(axis='y', alpha=0.3)

# Right: Normalized energy (percentage per class)
ax = axes[1]
for i, class_id in enumerate(range(3)):
    band_means_db = np.array([np.mean(band_energy[class_id][band]) for band in bands_list])
    # dB is a log scale, so dividing dB values by their sum is not an energy
    # share (for negative dB the quietest band would get the largest
    # "percentage"). Convert to linear power first; subtracting the peak only
    # rescales all bands equally and keeps the exponent well-behaved.
    # NOTE(review): assumes the cached values are power dB (10*log10) - confirm
    # against the preprocessing script.
    band_power = np.power(10.0, (band_means_db - band_means_db.max()) / 10.0)
    percentages = 100 * band_power / band_power.sum()
    ax.bar(x + i*width, percentages, width, label=LABELS[class_id], color=colors[i], alpha=0.8)

ax.set_xlabel('Frequency Band', fontsize=11)
ax.set_ylabel('Energy Percentage (%)', fontsize=11)
ax.set_title('Normalized Energy Distribution', fontsize=12, fontweight='bold')
ax.set_xticks(x + width)
ax.set_xticklabels(['Low', 'Mid', 'High'])
ax.legend()
ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

# Print statistics
print("\nFrequency Band Energy Statistics:")
print("-" * 60)
for i, name in enumerate(LABELS):
    print(f"{name}:")
    for band in bands_list:
        mean_val = np.mean(band_energy[i][band])
        std_val = np.std(band_energy[i][band])
        print(f"  {band:5s}: {mean_val:6.2f} ± {std_val:5.2f} dB")
    print()

### 3. Spectral Centroid

The spectral centroid indicates the "center of mass" of the spectrum - where most energy is concentrated.  
Lower centroid = more low-frequency content (expected for snore).

In [None]:
# Compute spectral centroid for each class
centroids = {i: [] for i in range(3)}
n_samples_centroid = 500

# Read the label vector once instead of one element at a time per class.
all_labels = np.asarray(labels[:]).astype(int)

for class_id in range(3):
    # First `n_samples_centroid` indices of this class, in dataset order.
    class_indices = np.flatnonzero(all_labels == class_id)[:n_samples_centroid].tolist()
    for idx in class_indices:
        spec = specs[idx]
        # Average across time to get frequency profile
        freq_profile = spec.mean(axis=1)  # [n_mels]

        # Compute centroid (weighted average of mel bin indices)
        mel_bins = np.arange(len(freq_profile))
        # Convert dB back to linear power for the weights: 10 ** (dB / 10),
        # not e ** (dB / 10). The profile is shifted by its peak first; the
        # centroid is a ratio, so a common scale factor cancels out.
        # NOTE(review): assumes the cached values are power dB (10*log10) -
        # confirm against the preprocessing script.
        weights = np.power(10.0, (freq_profile - freq_profile.max()) / 10.0)
        weights = np.maximum(weights, 1e-10)  # Avoid division by zero
        centroid = np.sum(mel_bins * weights) / np.sum(weights)
        centroids[class_id].append(centroid)

print(f"Computed spectral centroids for {n_samples_centroid} samples per class")

In [None]:
# Plot spectral centroid distributions
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Defined here so the cell does not rely on a palette left behind by an
# earlier plotting cell.
colors = ['#2ecc71', '#3498db', '#9b59b6']

# Left: Violin plot
ax = axes[0]
parts = ax.violinplot([centroids[i] for i in range(3)], positions=range(3),
                       showmeans=True, showmedians=True)
for pc, color in zip(parts['bodies'], colors):
    pc.set_facecolor(color)
    pc.set_alpha(0.7)

ax.set_xticks(range(3))
ax.set_xticklabels(LABELS)
ax.set_ylabel('Spectral Centroid (Mel Bin)', fontsize=11)
ax.set_title('Spectral Centroid Distribution', fontsize=12, fontweight='bold')
ax.grid(axis='y', alpha=0.3)

# Right: Box plot with statistical comparison
ax = axes[1]
# Tick labels are set on the axes rather than through boxplot's `labels=`
# keyword: Matplotlib 3.9 deprecated that keyword in favour of `tick_labels=`,
# and setting the labels afterwards works on both older and newer releases.
bp = ax.boxplot([centroids[i] for i in range(3)], patch_artist=True)
ax.set_xticklabels(LABELS)
for patch, color in zip(bp['boxes'], colors):
    patch.set_facecolor(color)
    patch.set_alpha(0.7)

ax.set_ylabel('Spectral Centroid (Mel Bin)', fontsize=11)
ax.set_title('Spectral Centroid Comparison', fontsize=12, fontweight='bold')
ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

# Print statistics
print("\nSpectral Centroid Statistics:")
print("-" * 60)
for i, name in enumerate(LABELS):
    mean_cent = np.mean(centroids[i])
    std_cent = np.std(centroids[i])
    median_cent = np.median(centroids[i])
    print(f"{name}:")
    print(f"  Mean:   {mean_cent:.2f} ± {std_cent:.2f}")
    print(f"  Median: {median_cent:.2f}")
    print()

# Interpretation
print("Interpretation:")
print("  Lower centroid → More low-frequency content")
print("  Higher centroid → More high-frequency content")

### 4. Class Separability Analysis

Measure how well classes can be distinguished based on acoustic features.  
Higher separability = easier classification.

In [None]:
# Compute pairwise class separability using Bhattacharyya distance
def bhattacharyya_distance(samples1, samples2):
    """Bhattacharyya distance between two 1-D samples, each fitted as a Gaussian.

    Each sample is summarised by its mean and standard deviation; a small
    constant is added to each standard deviation so a constant sample does
    not cause a division by zero.

    Returns:
        The sum of a spread term (zero when both variances are equal) and a
        mean-shift term (zero when both means are equal).
    """
    eps = 1e-6
    mu_a, mu_b = np.mean(samples1), np.mean(samples2)
    var_a = (np.std(samples1) + eps) ** 2
    var_b = (np.std(samples2) + eps) ** 2

    # Penalises a mismatch between the two variances.
    spread_term = 0.25 * np.log(0.25 * ((var_a / var_b) + (var_b / var_a) + 2))
    # Penalises the gap between the means, scaled by the combined variance.
    shift_term = 0.25 * ((mu_a - mu_b) ** 2 / (var_a + var_b))
    return spread_term + shift_term

# Pre-compute class indices from one bulk read of the labels. Labels outside
# 0..2 are simply not selected instead of raising a KeyError.
print("Finding class indices...")
n_per_class = 300
all_labels = np.asarray(labels[:]).astype(int)
class_indices_dict = {
    class_id: np.flatnonzero(all_labels == class_id)[:n_per_class].tolist()
    for class_id in range(3)
}

print(f"Found {n_per_class} samples per class")

# Compute separability matrix
print("Computing separability matrix...")
separability_matrix = np.zeros((3, 3))
feature_name = "mean_energy"

# The feature (mean value of each spectrogram) is computed once per class and
# reused for every pair, rather than re-reading the spectrograms per pair.
class_features = {
    class_id: [specs[idx].mean() for idx in class_indices_dict[class_id]]
    for class_id in range(3)
}

for i in range(3):
    for j in range(i + 1, 3):
        distance = bhattacharyya_distance(class_features[i], class_features[j])
        # The distance is symmetric, so both halves are filled and the
        # diagonal stays zero.
        separability_matrix[i, j] = distance
        separability_matrix[j, i] = distance

print(f"\nPairwise Class Separability (Bhattacharyya Distance on {feature_name}):")
print("-" * 60)
for i in range(3):
    for j in range(i+1, 3):
        print(f"{LABELS[i]} vs {LABELS[j]}: {separability_matrix[i, j]:.4f}")

In [None]:
# Visualize separability matrix
fig, ax = plt.subplots(figsize=(8, 6))

# `separability_matrix` already stores each pairwise distance at both [i, j]
# and [j, i], so adding its transpose would double every displayed value.
# The element-wise maximum leaves a symmetric matrix unchanged and would
# also mirror a matrix that only had one triangle filled.
sep_sym = np.maximum(separability_matrix, separability_matrix.T)
im = ax.imshow(sep_sym, cmap='YlOrRd', aspect='auto')

ax.set_xticks(range(3))
ax.set_yticks(range(3))
ax.set_xticklabels(LABELS)
ax.set_yticklabels(LABELS)

# Add text annotations (the diagonal is skipped: a class against itself)
for i in range(3):
    for j in range(3):
        if i != j:
            ax.text(j, i, f'{sep_sym[i, j]:.3f}',
                    ha="center", va="center", color="black", fontsize=11)

ax.set_title('Pairwise Class Separability Matrix', fontsize=13, fontweight='bold')
plt.colorbar(im, ax=ax, label='Bhattacharyya Distance')
plt.tight_layout()
plt.show()

print("\nInterpretation:")
print("  Higher values indicate better separability between classes")
print("  Values > 0.1 suggest good discriminability")

In [None]:
# Close file
h5.close()
print("Done!")