# Cellclass + Models Workflow (Examples)

This notebook shows a practical workflow from processed data to model comparison and discrepancy review.

It is organized as a setup cell followed by:
1. Aggregate units by age group
2. Compare a fixed `k=2` GMM against `type_u`
3. Read the discrepancy summaries
4. Load the full list of disagreeing neurons
5. Generate per-neuron review artifacts (waveform + context)
6. Quick scatter plot of disagreements in feature space

In [None]:
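from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
from IPython.display import display  # built in inside Jupyter; explicit import for other runners

# Output locations, relative to the working directory (normally the notebook's folder)
RESULTS_ROOT = Path("../results")
TYPEU_OUT = RESULTS_ROOT / "type_u_comparison"  # written by compare_type_u.py
REVIEW_OUT = TYPEU_OUT / "review"               # written by disagreement_review.py
AGE_GROUPS = ["P15_16", "P17_18", "P19_20", "P21_22", "P23_24", "P25"]

RESULTS_ROOT, TYPEU_OUT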

## 1) Aggregate by Age Group

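Run the terminal commands in this notebook from the repository root. The multi-line examples use bash-style `\` line continuations; in PowerShell use a backtick (`` ` ``) instead, in CMD use `^`, or put the whole command on one line.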

```bash
python scripts/aggreggate_by_age.py --qc
```

This writes per-age files under `results/<AGE>/` such as:
- `<AGE>_all_units.parquet`
- `<AGE>_clean_units.parquet`
- `<AGE>_ml_units.parquet`
- `<AGE>_X.npz`
- `<AGE>_ml_meta.json`

Quick check that the per-age ML unit tables were written:

In [None]:
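# Count units and sessions in each per-age ML table; missing files are reported and skipped
rows = []
for g in AGE_GROUPS:
    p = RESULTS_ROOT / g / f"{g}_ml_units.parquet"
    if not p.exists():
        print(f"Missing: {p}")
        continue
    d = pd.read_parquet(p)
    rows.append({"age_group": g, "n_units": len(d), "n_sessions": d["session_id"].nunique()})

# Explicit columns keep sort_values from raising a KeyError when no files are found
pd.DataFrame(rows, columns=["age_group", "n_units", "n_sessions"]).sort_values("age_group")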
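The check above only looks at `*_ml_units.parquet`. A minimal sketch for confirming that every per-age file listed earlier was written, assuming only the `results/<AGE>/<AGE>_<suffix>` naming shown in that list (the helper name `artifact_table` is hypothetical):

```python
from pathlib import Path

import pandas as pd

# Suffixes of the per-age files written by the aggregation step (from the list above).
SUFFIXES = ["all_units.parquet", "clean_units.parquet", "ml_units.parquet", "X.npz", "ml_meta.json"]


def artifact_table(results_root, age_groups, suffixes=SUFFIXES):
    """Return one row per age group and one boolean column per expected file."""
    rows = []
    for g in age_groups:
        row = {"age_group": g}
        for s in suffixes:
            # Assumes files live at <results_root>/<AGE>/<AGE>_<suffix>.
            row[s] = (Path(results_root) / g / f"{g}_{s}").exists()
        rows.append(row)
    return pd.DataFrame(rows).set_index("age_group")
```

In the notebook, `artifact_table(RESULTS_ROOT, AGE_GROUPS)` gives a True/False grid; any False cell points at an age group worth re-running `aggreggate_by_age.py` for.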

## 2) Compare GMM (`k=2`) vs `type_u`

Example command (default feature set):

```bash
python src/models/compare_type_u.py \
  --results_root results \
  --out_root results/type_u_comparison \
  --n_init 20
```

Example command (custom feature set):

```bash
python src/models/compare_type_u.py \
  --results_root results \
  --out_root results/type_u_comparison_smallfeat \
  --features fr_hz,burst_index,cv2,spk_duration_ms,spk_asymmetry \
  --n_init 20
```
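GMM component indices are arbitrary (component 0 is not inherently either `type_u` class), so comparing a `k=2` clustering with a two-class label requires aligning the two labelings before counting disagreements. A minimal sketch of one way to do this, assuming a two-valued `type_u` column, a two-valued cluster column named `gmm_label`, and no missing values (i.e. comparable rows only); it is not necessarily how `compare_type_u.py` does it, and the function and column names are hypothetical:

```python
import pandas as pd


def align_and_flag(df, cluster_col="gmm_label", label_col="type_u"):
    """Map the two GMM clusters onto the two label values and flag disagreements.

    Both one-to-one mappings (cluster id -> label) are tried and the one with
    the higher agreement is kept, because GMM cluster ids carry no meaning.
    """
    labels = sorted(df[label_col].unique().tolist())
    clusters = sorted(df[cluster_col].unique().tolist())
    if len(labels) != 2 or len(clusters) != 2:
        raise ValueError(f"need exactly 2 labels and 2 clusters, got {labels} and {clusters}")

    best = None
    for mapping in (
        {clusters[0]: labels[0], clusters[1]: labels[1]},
        {clusters[0]: labels[1], clusters[1]: labels[0]},
    ):
        predicted = df[cluster_col].map(mapping)
        agreement = (predicted == df[label_col]).mean()
        if best is None or agreement > best[0]:
            best = (agreement, mapping, predicted)

    agreement, mapping, predicted = best
    out = df.copy()
    out["gmm_as_type_u"] = predicted                 # cluster translated into type_u terms
    out["discrepancy"] = predicted != df[label_col]  # True where GMM and type_u disagree
    return out, mapping, agreement
```

With toy labels `"FS"`/`"RS"` (made up for illustration), clusters `[0, 0, 1, 1, 1]` against `["FS", "FS", "RS", "RS", "FS"]` align as `{0: "FS", 1: "RS"}` with 80% agreement, and only the last unit is flagged.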

## 3) Read discrepancy summaries

In [None]:
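# Summary tables written by compare_type_u.py
by_age = pd.read_csv(TYPEU_OUT / "discrepancies_by_age.csv")
by_session = pd.read_csv(TYPEU_OUT / "discrepancies_by_session.csv")

# Age groups with the highest disagreement first, then the first 15 session rows
display(by_age.sort_values("discrepancy_rate", ascending=False))
display(by_session.head(15))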

In [None]:
# Bar chart of the per-age discrepancy rate (age-group names sort chronologically as strings)
fig, ax = plt.subplots(figsize=(8, 4))
tmp = by_age.sort_values("age_group")
ax.bar(tmp["age_group"], tmp["discrepancy_rate"])
ax.set_ylabel("discrepancy rate")
ax.set_title("GMM (k=2) vs type_u disagreement by age group")
ax.tick_params(axis="x", rotation=0)
fig.tight_layout()
plt.show()
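If a summary number looks surprising, the per-age rate can be recomputed from the per-neuron table as the fraction of comparable units flagged as discrepant. A minimal sketch, assuming the `age_group`, `comparable` and `discrepancy` columns of `all_age_groups_gmm2_vs_type_u.parquet` (used in section 6); the function name is hypothetical, and the result matches `discrepancies_by_age.csv` only if the script uses the same definition:

```python
import pandas as pd


def discrepancy_rate_by_age(df, age_order=None):
    """Fraction of comparable units per age group whose GMM label disagrees with type_u."""
    cmp = df[df["comparable"].eq(True)]
    summary = (
        cmp.groupby("age_group")["discrepancy"]
        .agg(n_comparable="size", n_discrepant="sum")
        .astype(int)
    )
    summary["discrepancy_rate"] = summary["n_discrepant"] / summary["n_comparable"]
    out = summary.reset_index()
    if age_order is not None:
        # Follow the given age order (e.g. AGE_GROUPS) instead of alphabetical order.
        rank = {g: i for i, g in enumerate(age_order)}
        out = out[out["age_group"].isin(age_order)]
        out = out.sort_values("age_group", key=lambda s: s.map(rank)).reset_index(drop=True)
    return out
```

Usage: `discrepancy_rate_by_age(pd.read_parquet(TYPEU_OUT / "all_age_groups_gmm2_vs_type_u.parquet"), AGE_GROUPS)`.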

## 4) Full list of disagreeing neurons

`compare_type_u.py` also writes the per-neuron disagreement list:
- `discrepant_neurons.parquet` (all discrepant neurons)
- `discrepant_neurons_all.csv` (all discrepant neurons, as CSV)
- `discrepant_neurons_top500.csv` (a 500-row subset)

Load and inspect:

In [None]:
# One row per neuron whose GMM assignment disagrees with type_u
dis = pd.read_csv(TYPEU_OUT / "discrepant_neurons_all.csv")
print("n_discrepant:", len(dis))
dis.head(10)
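To see which direction the disagreements go (which `type_u` label the GMM tends to override), a cross-tabulation per age group is often easier to read than the raw table. A minimal sketch; the column names `type_u` and `gmm_as_type_u` (the GMM cluster expressed as a `type_u` label) are assumptions, so substitute whatever `discrepant_neurons_all.csv` actually contains. The function name is hypothetical:

```python
import pandas as pd


def disagreement_directions(dis, label_col="type_u", pred_col="gmm_as_type_u"):
    """Count discrepant neurons per age group by direction ("type_u label -> GMM label")."""
    direction = (dis[label_col].astype(str) + " -> " + dis[pred_col].astype(str)).rename("direction")
    return pd.crosstab(dis["age_group"], direction)
```

Usage: `disagreement_directions(dis)`; a strongly lopsided column suggests the GMM boundary sits on one side of the `type_u` boundary rather than the errors being random.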

## 5) Generate per-neuron review artifacts (waveform + context)

Run from the repository root:

```bash
python src/models/disagreement_review.py \
  --comparison_root results/type_u_comparison \
  --interim_root data/interim \
  --out_root results/type_u_comparison/review
```

Optional flags (example values):
- `--ages P23_24` (review a single age group)
- `--top_n 200` (review at most 200 neurons)

Outputs:
- `discrepant_neurons_review.csv`
- `discrepant_neurons_review.parquet`
- one PNG per neuron under `review/plots/<AGE>/`
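A minimal sketch for checking that plots were written for each reviewed age group, assuming only the `review/plots/<AGE>/*.png` layout listed above (no particular file-naming scheme is assumed); the function name is hypothetical:

```python
from pathlib import Path

import pandas as pd


def count_review_plots(review_root):
    """Count PNG files in each <review_root>/plots/<AGE>/ subfolder."""
    plots_root = Path(review_root) / "plots"
    if not plots_root.is_dir():
        return pd.Series(dtype="int64", name="n_png")
    counts = {
        age_dir.name: sum(1 for _ in age_dir.glob("*.png"))
        for age_dir in sorted(plots_root.iterdir())
        if age_dir.is_dir()
    }
    return pd.Series(counts, name="n_png", dtype="int64")
```

Usage: `count_review_plots(REVIEW_OUT)`; if the review table has an `age_group` column (an assumption), these counts can be compared against its per-age row counts.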

In [None]:
review_csv = REVIEW_OUT / "discrepant_neurons_review.csv"
if review_csv.exists():
    r = pd.read_csv(review_csv)
    display(r.head(10))
    print("Total reviewed neurons:", len(r))
else:
    print(f"Missing: {review_csv} (run disagreement_review.py first)")

## 6) Quick scatter: disagreements in feature space

In [None]:
# Keep only units flagged as comparable, then plot one panel per age group
all_df = pd.read_parquet(TYPEU_OUT / "all_age_groups_gmm2_vs_type_u.parquet")
cmp = all_df[all_df["comparable"].eq(True)].copy()

n = len(AGE_GROUPS)
fig, axes = plt.subplots(nrows=n, ncols=1, figsize=(8, 4 * n), sharex=True, sharey=True, squeeze=False)
for ax, g in zip(axes[:, 0], AGE_GROUPS):
    d = cmp[cmp["age_group"] == g]
    ok = d[d["discrepancy"].eq(False)]   # GMM agrees with type_u
    bad = d[d["discrepancy"].eq(True)]   # GMM disagrees with type_u

    ax.scatter(ok["spk_duration_ms"], ok["fr_hz"], s=8, alpha=0.25, color="gray", label="match")
    ax.scatter(bad["spk_duration_ms"], bad["fr_hz"], s=14, alpha=0.85, color="red", label="discrepancy")
    ax.set_title(g)
    ax.set_ylabel("fr_hz")

axes[0, 0].legend(frameon=False, fontsize=8)
axes[-1, 0].set_xlabel("spk_duration_ms")
fig.tight_layout()
plt.show()
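The scatter shows only two features. A numeric companion covering the whole feature set is the per-age median of each feature for matching versus discrepant units. A minimal sketch, assuming the feature columns from the custom-feature example in section 2 and a boolean `discrepancy` column; features missing from the table are skipped, and the function name is hypothetical:

```python
import pandas as pd

FEATURES = ["fr_hz", "burst_index", "cv2", "spk_duration_ms", "spk_asymmetry"]


def feature_medians(cmp, features=FEATURES):
    """Median of each feature per age group, split into match vs discrepancy columns."""
    present = [f for f in features if f in cmp.columns]
    status = cmp["discrepancy"].map({True: "discrepancy", False: "match"}).rename("status")
    # Group by age_group (a column) and status (an aligned Series), then pivot status into columns.
    return cmp.groupby(["age_group", status])[present].median().unstack("status")
```

Usage: `feature_medians(cmp)` after running the scatter cell; the result has one row per age group and `(feature, status)` column pairs.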