# Notebook 04: Full Dataset Generation
**Project:** Synthetic Sleep Environment Dataset Generator  
**Authors:** Rushav Dash & Lisa Li  
**Course:** TECHIN 513 — Signal Processing & Machine Learning  
**University:** University of Washington  
**Date:** 2026-02-19

## Table of Contents
1. [Setup & Imports](#section-1)
2. [Pipeline Overview](#section-2)
3. [Initialize Generator](#section-3)
4. [Run Setup (Download + Train)](#section-4)
5. [Generate 5,000 Sessions](#section-5)
6. [Preview Output](#section-6)
7. [Diversity Checks](#section-7)
8. [Signal Shape Verification](#section-8)
9. [Save Dataset](#section-9)
10. [Summary Statistics](#section-10)

---
## 1. Setup & Imports <a id='section-1'></a>

In [None]:
import sys, os
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), '..'))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from src.dataset_generator import SleepDatasetGenerator

%matplotlib inline
plt.rcParams.update({'figure.dpi': 120, 'font.size': 11})
sns.set_theme(style='whitegrid')
print('Setup complete.')

---
## 2. Pipeline Overview <a id='section-2'></a>

The `SleepDatasetGenerator` orchestrates four stages:

```
DataLoader  →  download + calibrate reference stats from real IoT data
    ↓
SignalGenerator  →  generate a 96-point time-series (8 h at 5-min intervals) per signal per session
    ↓
FeatureExtractor  →  extract 29+ scalar features from each time-series
    ↓
SleepQualityModel  →  assign realistic sleep quality labels via Random Forest
```
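The hand-off between `SignalGenerator` and `FeatureExtractor` can be made concrete with a minimal sketch for a single temperature signal. The cosine cooling curve, the noise level and the four features are hypothetical stand-ins chosen for illustration. The project's generators are calibrated on real data and produce 29+ features.

```python
import numpy as np

N_POINTS = 96  # 8 h at 5-min resolution

def generate_temperature(rng, base=20.0, amplitude=1.5, noise_sd=0.2):
    """Hypothetical stand-in for SignalGenerator: one night of bedroom temperature."""
    t = np.linspace(0.0, 1.0, N_POINTS)
    # cos(pi*t) runs from +1 to -1, so the room cools from base+amplitude to base-amplitude
    drift = amplitude * np.cos(np.pi * t)
    return base + drift + rng.normal(0.0, noise_sd, N_POINTS)

def extract_features(ts, optimal=(18.0, 21.0)):
    """Hypothetical stand-in for FeatureExtractor: a few scalar summaries of one series."""
    ts = np.asarray(ts, dtype=float)
    lo, hi = optimal
    return {
        "temp_mean": float(ts.mean()),
        "temp_std": float(ts.std()),
        "temp_range": float(ts.max() - ts.min()),
        # share of the night spent inside the 18-21 °C optimal zone, in percent
        "temp_pct_optimal": float(np.mean((ts >= lo) & (ts <= hi)) * 100),
    }

rng = np.random.default_rng(42)
features = extract_features(generate_temperature(rng))
print(features)
```

Each session would repeat this per signal and concatenate the feature dicts into one row before the labelling stage.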

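**Stratification:** 5,000 sessions are split equally across 4 seasons (1,250 each).
Within each season, age group (young/middle/senior) and sensitivity (low/normal/high)
are cycled uniformly. Because 1,250 is not divisible by 3, the per-level counts within a
season can only be approximately equal.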
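One way to implement this kind of round-robin assignment is sketched below. It assumes the generator builds a flat plan of session attributes before generating signals; the pairing rule (`i % 3` for age, `(i // 3) % 3` for sensitivity) is a hypothetical choice that visits all nine age × sensitivity combinations every nine sessions, and the real `SleepDatasetGenerator` may cycle differently.

```python
from collections import Counter

SEASONS = ["winter", "spring", "summer", "fall"]
AGE_GROUPS = ["young", "middle", "senior"]
SENSITIVITIES = ["low", "normal", "high"]

def stratify(n_sessions=5000):
    """Return one dict per session with its season, age group and sensitivity (sketch)."""
    per_season = n_sessions // len(SEASONS)
    plan = []
    for season in SEASONS:
        for i in range(per_season):
            plan.append({
                "season": season,
                "age_group": AGE_GROUPS[i % 3],               # cycles every session
                "sensitivity": SENSITIVITIES[(i // 3) % 3],   # advances every 3 sessions
            })
    return plan

plan = stratify(5000)
print(Counter(p["season"] for p in plan))
print(Counter(p["age_group"] for p in plan))
print(Counter(p["sensitivity"] for p in plan))
```

With 1,250 sessions per season, the leftover two sessions after the last full cycle of nine are what make the age and sensitivity counts differ by a handful.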

---
## 3. Initialize Generator <a id='section-3'></a>

In [None]:
# global_seed=42 makes the full dataset exactly reproducible
gen = SleepDatasetGenerator(
    global_seed=42,
    n_sessions=5000,
    include_sound=True,
    include_humidity=True,
    verbose=True,
)
print('Generator initialised.')
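The dataset carries a `random_seed` column, which suggests that each session gets its own seed derived from `global_seed`. A common way to do that with NumPy is `SeedSequence.spawn`, sketched here; whether `SleepDatasetGenerator` uses this exact scheme is an assumption, and `session_seeds` is a hypothetical helper.

```python
import numpy as np

def session_seeds(global_seed, n_sessions):
    """Derive one independent, reproducible integer seed per session."""
    children = np.random.SeedSequence(global_seed).spawn(n_sessions)
    # generate_state(1) yields a single uint32 that can be stored in a CSV column
    return [int(child.generate_state(1)[0]) for child in children]

seeds_a = session_seeds(42, 5)
seeds_b = session_seeds(42, 5)
print(seeds_a == seeds_b)                 # same global seed -> same per-session seeds
rng = np.random.default_rng(seeds_a[0])   # each session draws from its own generator
```

Spawning child sequences avoids the overlapping-stream risk of naive schemes such as `global_seed + session_index`, and any single session can be regenerated from its stored seed alone.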

---
## 4. Run Setup: Download Datasets & Train ML Model <a id='section-4'></a>

`gen.setup()` does three things:
1. Downloads all Kaggle datasets via kagglehub (cached after first run)
2. Extracts calibration statistics from real IoT data
3. Trains Random Forest models on the Sleep Efficiency dataset

In [None]:
# First run: downloads the Kaggle datasets and trains the Random Forest models.
# Later runs reuse the kagglehub cache and the previously trained models.
gen.setup()
print('\nSetup complete!')
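Step 2 of `gen.setup()` is internal to the generator. As a rough picture of what "calibration statistics" can mean, the sketch below summarizes sensor readings per season into reference values that a signal generator could sample around. The `season` and `temperature` column names and the chosen statistics are assumptions, and the toy readings are made up for illustration only; they are not real IoT data.

```python
import pandas as pd

def calibration_stats(readings, value_col="temperature", group_col="season"):
    """Summarize sensor readings into per-group reference statistics (hypothetical schema)."""
    grouped = readings.groupby(group_col)[value_col]
    stats = grouped.agg(["mean", "std", "min", "max"])
    # robust bounds are often safer than min/max for clipping synthetic signals
    stats["p05"] = grouped.quantile(0.05)
    stats["p95"] = grouped.quantile(0.95)
    return stats

# Toy readings for illustration only, not real IoT data
toy = pd.DataFrame({
    "season": ["winter"] * 4 + ["summer"] * 4,
    "temperature": [17.5, 18.0, 18.5, 19.0, 23.0, 24.0, 25.0, 26.0],
})
stats = calibration_stats(toy)
print(stats.round(2))
```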

---
## 5. Generate 5,000 Sessions <a id='section-5'></a>

The generation loop shows a tqdm progress bar and logs every 500 sessions.

In [None]:
df = gen.generate()
print(f'\nGenerated dataset shape: {df.shape}')
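`gen.generate()` hides its loop. A minimal sketch of how such a loop might be structured, with plain `print` logging standing in for the tqdm bar; `generate_sessions` and the `make_session` factory are hypothetical, and the real factory would run the full signal, feature and labelling pipeline.

```python
import pandas as pd

def generate_sessions(n_sessions, make_session, log_every=500):
    """Build one row per session, logging progress every `log_every` sessions (sketch)."""
    rows = []
    for i in range(n_sessions):
        rows.append(make_session(i))
        if (i + 1) % log_every == 0:
            print(f"  generated {i + 1:,}/{n_sessions:,} sessions")
    # one DataFrame at the end is much cheaper than concatenating row by row
    return pd.DataFrame(rows)

seasons = ["winter", "spring", "summer", "fall"]
df_demo = generate_sessions(1000, lambda i: {"session_index": i, "season": seasons[i // 250]})
print(df_demo.shape)
```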

---
## 6. Preview Output <a id='section-6'></a>

In [None]:
# Drop time-series columns for display (they're long JSON strings)
ts_cols = [c for c in df.columns if c.startswith('ts_')]
df_display = df.drop(columns=ts_cols)
df_display.head(10)
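Because each `ts_` cell holds a JSON-encoded list, analysis usually starts by decoding a column into a NumPy array with one row per session. A minimal sketch; `decode_ts` is a hypothetical helper, and the toy frame only demonstrates the encode/decode round trip.

```python
import json
import numpy as np
import pandas as pd

def decode_ts(df, col):
    """Parse a column of JSON-encoded series into a 2-D array (sessions x time points)."""
    # raises ValueError if the series do not all have the same length
    return np.array([json.loads(s) for s in df[col]], dtype=float)

# Toy frame: two 96-point series stored as JSON strings, as in the generated dataset
toy = pd.DataFrame({"ts_temperature": [
    json.dumps(np.full(96, 20.0).tolist()),
    json.dumps(np.linspace(22.0, 18.0, 96).tolist()),
]})
arr = decode_ts(toy, "ts_temperature")
print(arr.shape)  # (2, 96)
```

With the notebook's `df`, `decode_ts(df, 'ts_temperature')` should give one row per session and one column per 5-minute step, assuming every series has the same length.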

In [None]:
print('Column list:')
for col in df.columns:
    print(f'  {col:<45} dtype={df[col].dtype}')

In [None]:
print('Descriptive statistics for sleep quality labels:')
df[['sleep_efficiency','awakenings','rem_pct','deep_pct','light_pct']].describe().round(4)

---
## 7. Diversity Checks <a id='section-7'></a>

Verify that sessions are balanced across seasons (target: 1,250 each), age groups, and sensitivity levels.

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Season distribution
season_counts = df['season'].value_counts().reindex(['winter','spring','summer','fall'])
axes[0].bar(season_counts.index, season_counts.values, color='steelblue', edgecolor='white')
axes[0].set_title('Sessions by Season')
axes[0].set_ylabel('Count')
axes[0].axhline(1250, color='red', linestyle='--', alpha=0.6, label='Target: 1250')
axes[0].legend()

# Age group distribution
age_counts = df['age_group'].value_counts().reindex(['young','middle','senior'])
axes[1].bar(age_counts.index, age_counts.values, color='coral', edgecolor='white')
axes[1].set_title('Sessions by Age Group')

# Sensitivity distribution
sens_counts = df['sensitivity'].value_counts().reindex(['low','normal','high'])
axes[2].bar(sens_counts.index, sens_counts.values, color='mediumseagreen', edgecolor='white')
axes[2].set_title('Sessions by Sensitivity')

fig.suptitle('Dataset Stratification Verification', fontsize=13)
plt.tight_layout()
plt.show()
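Bar charts make imbalance visible but are easy to misread by a few sessions. The helper below (hypothetical, not part of `src`) turns the same check into numbers. It also reports labels outside the expected set, such as a capitalized `'Summer'`, which the `reindex` calls in the cell above would silently turn into a missing bar.

```python
import pandas as pd

def balance_report(df, col, levels, expected=None):
    """Summarize how evenly sessions are spread across the expected levels of `col`."""
    counts = df[col].value_counts().reindex(levels, fill_value=0)
    report = {
        "counts": {k: int(v) for k, v in counts.items()},
        "spread": int(counts.max() - counts.min()),   # 0 means perfectly balanced
        # labels present in the data but not in `levels` (typos, casing, ...)
        "unexpected": sorted(set(df[col].dropna().unique()) - set(levels)),
    }
    if expected is not None:
        report["matches_expected"] = bool((counts == expected).all())
    return report

toy = pd.DataFrame({"season": ["winter", "winter", "spring", "Summer"]})
print(balance_report(toy, "season", ["winter", "spring", "summer", "fall"]))
```

On the real dataset, `balance_report(df, 'season', ['winter', 'spring', 'summer', 'fall'], expected=1250)` checks the season target directly; for age group and sensitivity, a small `spread` rather than an exact count is the meaningful test.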

---
## 8. Signal Shape Verification <a id='section-8'></a>

Plot a 3×3 grid of randomly selected temperature time-series to visually confirm
signal diversity and realism.

In [None]:
rng_sample = np.random.default_rng(0)
sample_idx = rng_sample.choice(len(df), size=9, replace=False)

fig, axes = plt.subplots(3, 3, figsize=(14, 9), sharex=True)
t = np.arange(0, 480, 5)  # time axis: 96 points, 5-min intervals

for ax, idx in zip(axes.flatten(), sample_idx):
    row = df.iloc[idx]
    temp_ts = json.loads(row['ts_temperature'])
    ax.plot(t, temp_ts, linewidth=1.2, color='tomato')
    ax.axhspan(18, 21, alpha=0.15, color='green', label='Optimal zone')
    ax.set_ylim(14, 30)
    ax.set_title(
        f"Session {idx} | {row['season']} | {row['age_group']}\n"
        f"eff={row['sleep_efficiency']:.2f}  awk={row['awakenings']}",
        fontsize=8
    )
    ax.set_ylabel('°C', fontsize=8)
    ax.set_xlabel('min', fontsize=8)

fig.suptitle('Temperature Time-Series: 9 Random Sessions\n'
             '(green band = optimal sleep zone 18–21°C)', fontsize=12)
plt.tight_layout()
plt.show()
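The grid above checks nine sessions by eye. A programmatic complement, sketched under the assumption that every `ts_` column should decode to 96 finite values, scans all sessions and counts violations per column; `ts_shape_report` is a hypothetical helper.

```python
import json
import numpy as np
import pandas as pd

def ts_shape_report(df, expected_len=96):
    """For every ts_* column, count sessions with the wrong length or non-finite values."""
    report = {}
    for col in [c for c in df.columns if c.startswith("ts_")]:
        series = [np.asarray(json.loads(s), dtype=float) for s in df[col]]
        report[col] = {
            "wrong_length": sum(len(x) != expected_len for x in series),
            "non_finite": sum(not np.isfinite(x).all() for x in series),
        }
    return report

# Toy frame: one good and one short temperature series, one sound series containing NaN
toy = pd.DataFrame({
    "ts_temperature": [json.dumps([20.0] * 96), json.dumps([20.0] * 95)],
    "ts_sound": [json.dumps([30.0] * 96), json.dumps([30.0] * 95 + [float("nan")])],
})
print(ts_shape_report(toy))
```

Running `ts_shape_report(df)` on the full dataset should return all zeros if every signal was generated at the expected resolution.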

---
## 9. Save Dataset <a id='section-9'></a>

In [None]:
csv_path, json_path = gen.save(df)
print(f'CSV  saved: {csv_path}')
print(f'JSON saved: {json_path}')
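`gen.save` returns both paths, but its file layout is not shown in this notebook. A plausible pandas-only equivalent is sketched below; the `out_dir` argument and the `sleep_dataset` file stem are hypothetical.

```python
from pathlib import Path
import pandas as pd

def save_dataset(df, out_dir, stem="sleep_dataset"):
    """Write the dataset as CSV and as a JSON array of records; return both paths (sketch)."""
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    csv_path = out_dir / f"{stem}.csv"
    json_path = out_dir / f"{stem}.json"
    df.to_csv(csv_path, index=False)                       # drop the RangeIndex
    df.to_json(json_path, orient="records", indent=2)      # one object per session
    return csv_path, json_path
```

Reading the CSV back and comparing `pd.read_csv(csv_path).shape` with `df.shape` is a cheap guard against truncated writes. The `ts_` columns remain JSON strings inside both files, so they still need `json.loads` after reloading.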

---
## 10. Summary Statistics <a id='section-10'></a>

In [None]:
label_cols = ['sleep_efficiency', 'awakenings', 'rem_pct', 'deep_pct', 'light_pct']
meta_cols = ['session_id', 'session_index', 'season', 'age_group', 'sensitivity', 'random_seed']
# Feature columns = everything that is not a time series, metadata, or a label
feature_cols = [c for c in df.columns if c not in ts_cols + meta_cols + label_cols]

print('=== DATASET SUMMARY ===')
print(f'Total sessions      : {len(df):,}')
print(f'Total columns       : {len(df.columns)}')
print(f'Time-series cols    : {len(ts_cols)}')
print(f'Feature cols        : {len(feature_cols)}')
print()
print('Sleep quality label ranges:')
for col in label_cols:
    print(f'  {col:<22}: [{df[col].min():.2f}, {df[col].max():.2f}]  mean={df[col].mean():.3f}')
print()
print(f'Stage sum check (should be ~100): {(df["rem_pct"]+df["deep_pct"]+df["light_pct"]).mean():.2f}')
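A mean stage sum close to 100 can still hide individual sessions that drift well away from it. A row-wise sketch that counts violating sessions per check follows. It assumes `sleep_efficiency` is a fraction in [0, 1] and that the three stage columns are percentages; the tolerance and the helper name `label_violations` are hypothetical, and the toy frame is made up to show one failing row.

```python
import pandas as pd

LABEL_COLS = ["sleep_efficiency", "awakenings", "rem_pct", "deep_pct", "light_pct"]

def label_violations(df, tol=1.0):
    """Count sessions that break each (assumed) label convention; all zeros means pass."""
    stage_sum = df["rem_pct"] + df["deep_pct"] + df["light_pct"]
    return {
        "efficiency_outside_0_1": int((~df["sleep_efficiency"].between(0, 1)).sum()),
        "negative_awakenings": int((df["awakenings"] < 0).sum()),
        "stage_sum_off_by_more_than_tol": int(((stage_sum - 100).abs() > tol).sum()),
        "rows_with_missing_labels": int(df[LABEL_COLS].isna().any(axis=1).sum()),
    }

# Toy labels: the second row's stages sum to 105, so it should be flagged
toy = pd.DataFrame({
    "sleep_efficiency": [0.88, 0.93],
    "awakenings": [1, 3],
    "rem_pct": [22.0, 25.0],
    "deep_pct": [20.0, 20.0],
    "light_pct": [58.0, 60.0],
})
print(label_violations(toy))
```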