# Data Story: DreamCatcher Cache + Preprocessing

This notebook explains the dataset subset and preprocessing artifacts used by the benchmark.

**Run order:**
1. `python3 scripts/preprocess.py`
2. Open and run this notebook

**Dependencies:** only local cache artifacts under `results/cache/spectrograms/`.


In [None]:
from pathlib import Path

import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Seed NumPy's global RNG so the random sample picks in later cells repeat across runs.
np.random.seed(42)

# The repo root is either the working directory or its parent, whichever holds
# a results/ folder (checked in that order).
CWD = Path.cwd()
for _base in (CWD, CWD.parent):
    if (_base / 'results').exists():
        ROOT = _base
        break
else:
    raise FileNotFoundError('Could not locate repo root containing results/.')

# Cache location plus the split / class vocabulary shared by the cells below.
CACHE_DIR = ROOT.joinpath('results', 'cache', 'spectrograms')
SPLITS = ['train', 'validation', 'test']
LABELS = ['quiet', 'breathe', 'snore']  # list position doubles as the integer class id
COLORS = ['#2ecc71', '#3498db', '#e74c3c']  # one colour per entry of LABELS

# Figure defaults for every plot in the notebook.
plt.rcParams.update({'figure.dpi': 120, 'axes.grid': True})


## A) Dataset Scope and Reproducibility Context

We use the 3-class subset: `quiet`, `breathe`, `snore`.
The cell below loads and displays the cached metadata of each split (`n_samples`, `n_mels`, `sample_rate`, `max_time`) and asserts that every split was cached with `n_mels = 64`.


In [None]:
# Collect, for every split, the cached label vector and the HDF5 attributes
# that describe how the split was preprocessed.
attr_names = ('n_samples', 'n_mels', 'max_time', 'sample_rate')
meta_rows = []
split_labels = {}

for split in SPLITS:
    cache_path = CACHE_DIR / f'{split}.h5'
    if not cache_path.exists():
        raise FileNotFoundError(f'Missing cache file: {cache_path}')
    with h5py.File(cache_path, 'r') as h5:
        # Labels are read fully into memory; the class-distribution cell reuses them.
        split_labels[split] = h5['labels'][:]
        row = {'split': split}
        row.update((name, int(h5.attrs[name])) for name in attr_names)
        meta_rows.append(row)

meta_df = pd.DataFrame(meta_rows)
display(meta_df)

# All splits are expected to share a single mel resolution of 64 bins.
unique_n_mels = sorted(set(meta_df['n_mels'].tolist()))
print('unique n_mels:', unique_n_mels)
assert unique_n_mels == [64], f'Expected n_mels=64, got {unique_n_mels}'


## B) Split-wise Class Distribution


In [None]:
# Tabulate, per split, how many samples fall into each class of LABELS and
# what share of the split that represents.
dist_rows = []
for split, labels in split_labels.items():
    total = len(labels)
    counted = 0
    for class_id, class_name in enumerate(LABELS):
        count = int((labels == class_id).sum())
        counted += count
        dist_rows.append(
            {
                'split': split,
                'class_id': class_id,
                'class_name': class_name,
                'count': count,
                # An empty split has no meaningful ratio; report 0 % rather
                # than failing with a ZeroDivisionError.
                'pct': 100.0 * count / total if total else 0.0,
            }
        )
    # Labels whose id is not a position in LABELS are not tabulated above, so
    # the percentages would silently stop summing to 100. Make that visible.
    if counted != total:
        print(
            f'warning: {total - counted} of {total} labels in split {split!r} '
            f'are outside the {len(LABELS)} known classes'
        )

dist_df = pd.DataFrame(dist_rows)
display(dist_df)


In [None]:
# Stacked bars of the class distribution per split: absolute counts on the
# left, percentages on the right.
#
# DataFrame.pivot returns its index sorted (test, train, validation), so the
# rows are put back into the order in which the splits appear in dist_df.
split_order = dist_df['split'].unique()

count_pivot = (
    dist_df.pivot(index='split', columns='class_name', values='count')
    .reindex(split_order)[LABELS]
)
pct_pivot = (
    dist_df.pivot(index='split', columns='class_name', values='pct')
    .reindex(split_order)[LABELS]
)

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
panels = [
    (count_pivot, 'Class Counts by Split', 'samples'),
    (pct_pivot, 'Class Ratio by Split', 'percent (%)'),
]
for ax, (pivot, title, ylabel) in zip(axes, panels):
    # rot=0 keeps the split names horizontal and readable.
    pivot.plot(kind='bar', stacked=True, ax=ax, color=COLORS, rot=0)
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.legend(title='class')

plt.tight_layout()
plt.show()


## C) Sample Spectrograms by Class (Train Split)


In [None]:
# Show a few randomly drawn training spectrograms per class, one row per class.
# All panels share one colour scale so that the single colorbar is valid for
# every image, not only for the last one drawn.
n_examples = 3
train_path = CACHE_DIR / 'train.h5'
with h5py.File(train_path, 'r') as h5:
    train_specs = h5['spectrograms']
    train_labels = h5['labels'][:]
    n_mels = train_specs.shape[1]

    per_class_idx = {}
    per_class_specs = {}
    for class_id in range(len(LABELS)):
        candidates = np.where(train_labels == class_id)[0]
        # Never ask for more examples than the class actually has; sampling
        # without replacement would otherwise raise.
        n_draw = min(n_examples, len(candidates))
        if n_draw:
            per_class_idx[class_id] = np.random.choice(candidates, size=n_draw, replace=False)
        else:
            per_class_idx[class_id] = np.empty(0, dtype=int)
        # Read the selected spectrograms while the file is still open.
        per_class_specs[class_id] = [train_specs[int(idx)] for idx in per_class_idx[class_id]]

drawn = [spec for class_specs in per_class_specs.values() for spec in class_specs]
if not drawn:
    raise ValueError('No training spectrograms found for any class in LABELS.')
vmin = min(float(np.min(spec)) for spec in drawn)
vmax = max(float(np.max(spec)) for spec in drawn)

# Constrained layout (unlike tight_layout) can place a colorbar that is
# attached to a whole grid of axes without overlapping the panels.
fig, axes = plt.subplots(
    len(LABELS),
    n_examples,
    figsize=(12, 8),
    sharex=True,
    sharey=True,
    constrained_layout=True,
)
mel_ticks = np.linspace(0, n_mels - 1, num=5, dtype=int)

for row, class_id in enumerate(range(len(LABELS))):
    for col in range(n_examples):
        ax = axes[row, col]
        if col < len(per_class_specs[class_id]):
            spec = per_class_specs[class_id][col]
            im = ax.imshow(
                spec,
                origin='lower',
                aspect='auto',
                cmap='magma',
                vmin=vmin,
                vmax=vmax,
            )
        else:
            # Fewer than n_examples samples exist for this class.
            ax.text(0.5, 0.5, 'no sample', transform=ax.transAxes, ha='center', va='center')
        if col == 0:
            ax.set_ylabel(LABELS[class_id])
        if row == len(LABELS) - 1:
            ax.set_xlabel('time frame')
        ax.set_yticks(mel_ticks)
        ax.set_yticklabels([str(int(t)) for t in mel_ticks])

fig.suptitle('Random Spectrogram Samples per Class (train)')
cbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.7)
cbar.set_label('log-mel (dB)')
plt.show()


## D) Preprocessing Diagnostics

Checks:
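- clip length / time-frame consistency (compare the `shape` column across splits)
- dB range and quantiles (min, 1 %, 50 %, 99 %, max)
- finite values (no `NaN`/`inf`) on a random sample of up to 2000 spectrograms per split; this check is asserted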


In [None]:
# For every split, draw a random sample of cached spectrograms and record the
# dataset shape, the fraction of finite values and the dB value range.
max_sample = 2000  # upper bound on spectrograms read per split
diag_rows = []

for split in SPLITS:
    p = CACHE_DIR / f'{split}.h5'
    if not p.exists():
        raise FileNotFoundError(f'Missing cache file: {p}')
    with h5py.File(p, 'r') as h5:
        specs = h5['spectrograms']
        n = specs.shape[0]
        if n == 0:
            # The reductions below are undefined on an empty sample.
            raise ValueError(f'Cache split {split!r} contains no spectrograms.')
        sample_n = min(max_sample, n)
        # h5py fancy indexing expects the indices in increasing order, hence the sort.
        idx = np.sort(np.random.choice(n, size=sample_n, replace=False))
        sample = specs[idx]

        finite_ratio = np.isfinite(sample).mean()
        # One quantile call for all three levels avoids repeating the
        # expensive pass over the whole sample.
        q01, q50, q99 = np.quantile(sample, [0.01, 0.50, 0.99])
        diag_rows.append(
            {
                'split': split,
                'shape': tuple(specs.shape),
                'finite_ratio': float(finite_ratio),
                'db_min': float(np.min(sample)),
                'db_q01': float(q01),
                'db_q50': float(q50),
                'db_q99': float(q99),
                'db_max': float(np.max(sample)),
            }
        )

diag_df = pd.DataFrame(diag_rows)
display(diag_df)

assert (diag_df['finite_ratio'] == 1.0).all(), 'Found non-finite values in sampled cache tensors.'


## E) Compact Class-Level Acoustic Summaries

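- Average spectrogram per class (shared color scale), computed from the first 400 train samples of each class
- Temporal energy profile (mean over mel bins) with a ±1 standard deviation variability band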


In [None]:
# Per-class summaries of the train split: the mean spectrogram, and the mean
# and standard deviation over samples of the per-frame energy.
n_per_class = 400  # upper bound on samples read per class
avg_specs = {}
temporal_mean = {}
temporal_std = {}

with h5py.File(CACHE_DIR / 'train.h5', 'r') as h5:
    specs = h5['spectrograms']
    labels = h5['labels'][:]

    for class_id in range(len(LABELS)):
        # Takes the first n_per_class matches in file order (not a random draw).
        idx = np.where(labels == class_id)[0][:n_per_class]
        class_specs = specs[idx]
        avg_specs[class_id] = class_specs.mean(axis=0)
        temp_profiles = class_specs.mean(axis=1)  # [N, time]
        temporal_mean[class_id] = temp_profiles.mean(axis=0)
        temporal_std[class_id] = temp_profiles.std(axis=0)

# Common colour limits so the three average spectrograms are directly comparable.
global_min = min(arr.min() for arr in avg_specs.values())
global_max = max(arr.max() for arr in avg_specs.values())

# Constrained layout (unlike tight_layout) can place a colorbar attached to
# several axes without letting it overlap the panels.
fig, axes = plt.subplots(
    1,
    len(LABELS),
    figsize=(12, 3.5),
    sharex=True,
    sharey=True,
    constrained_layout=True,
)
for i, class_id in enumerate(range(len(LABELS))):
    ax = axes[i]
    im = ax.imshow(
        avg_specs[class_id],
        origin='lower',
        aspect='auto',
        cmap='viridis',
        vmin=global_min,
        vmax=global_max,
    )
    ax.set_title(LABELS[class_id])
    ax.set_xlabel('time frame')
    if i == 0:
        ax.set_ylabel('mel bin')

cbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.8)
cbar.set_label('avg log-mel (dB)')
plt.show()

# Temporal profile: mean line per class with a +/- one standard deviation band.
t = np.arange(next(iter(temporal_mean.values())).shape[0])
fig, ax = plt.subplots(figsize=(10, 4))
for class_id, color in enumerate(COLORS):
    m = temporal_mean[class_id]
    s = temporal_std[class_id]
    ax.plot(t, m, color=color, label=LABELS[class_id], lw=2)
    ax.fill_between(t, m - s, m + s, color=color, alpha=0.18)

ax.set_title('Temporal Energy Profile (train)')
ax.set_xlabel('time frame')
ax.set_ylabel('mean energy (dB)')
ax.legend()
plt.tight_layout()
plt.show()


## Notes

- This notebook is intentionally compact and preprocessing-focused.
- Model comparison and KD interpretation are in `notebooks/results_analysis.ipynb`.
