# Accents-PT-BR — Accent Classifier Ablation (CNN vs wav2vec2)

**Projeto:** Controle Explícito de Sotaque Regional em pt-BR  
**Objetivo:** Treinar e avaliar classificadores de sotaque (CNN mel-spectrogram vs wav2vec2 fine-tuned) no dataset combinado Accents-PT-BR (CORAA-MUPE + Common Voice PT). Esses classificadores servem como **avaliadores externos** para os Stages 2-3 (medir se o áudio gerado pelo LoRA carrega o sotaque-alvo).  
**Config:** `configs/accent_classifier.yaml` (single source of truth).  
**Dataset:** Accents-PT-BR = CORAA-MUPE (entrevistados) + Common Voice PT (accent label normalizado).  

**Seções:**
1. Setup do ambiente
2. CORAA-MUPE manifest
3. Common Voice PT manifest
4. Dataset combinado Accents-PT-BR
5. Análise de confounds (incluindo accent x source)
6. Speaker-disjoint splits
7. CNN accent classifier (treinamento + avaliação)
8. wav2vec2 accent classifier (treinamento + avaliação)
9. Cross-source evaluation (confound check)
10. Ablation summary + report

Este notebook é a **camada de orquestração**. Toda lógica está em `src/` (testável, auditável).  
O notebook apenas: instala deps → configura ambiente → chama módulos → exibe resultados.

In [None]:
import os, subprocess, sys

# --- Platform-aware setup: works on Colab, Lightning.ai, and local ---
# Detection order: Lightning.ai -> Google Colab -> Local


def _find_git_root(start):
    """Return the nearest directory at or above *start* containing `.git`.

    Returns None when no ancestor of *start* is a git checkout.
    """
    current = os.path.abspath(start)
    while True:
        if os.path.exists(os.path.join(current, '.git')):
            return current
        parent = os.path.dirname(current)
        if parent == current:  # reached the filesystem root
            return None
        current = parent


# 1. Determine repo directory
_lightning_studio = '/teamspace/studios/this_studio'
if os.path.exists(_lightning_studio):
    REPO_DIR = os.path.join(_lightning_studio, 'TCC')
    _platform = 'lightning'
elif 'google.colab' in sys.modules or os.path.exists('/content'):
    REPO_DIR = '/content/TCC'
    _platform = 'colab'
else:
    # Local runs usually start inside the checkout, often from a subdirectory
    # such as notebooks/. Use the enclosing repository root; if there is none,
    # clone into ./TCC instead of reusing (and wiping) the working directory.
    REPO_DIR = _find_git_root(os.getcwd()) or os.path.join(os.getcwd(), 'TCC')
    _platform = 'local'

# 2. Clone repo if needed (idempotent)
if not os.path.exists(os.path.join(REPO_DIR, '.git')):
    if os.path.exists(REPO_DIR):
        if _platform == 'local':
            # Never delete a user's directory on a local machine.
            raise RuntimeError(
                f'{REPO_DIR} exists but is not a git checkout; refusing to delete it. '
                'Remove it manually or run the notebook from inside the repository.'
            )
        # Managed platforms (Colab/Lightning): a directory without .git is a
        # broken or partial clone, so it is safe to recreate it.
        subprocess.run(['rm', '-rf', REPO_DIR], check=False)
    subprocess.run(
        ['git', 'clone', 'https://github.com/paulohenriquevn/tcc.git', REPO_DIR],
        check=True,
    )

os.chdir(REPO_DIR)
if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)

# 3. Install dependencies
subprocess.run([sys.executable, '-m', 'pip', 'install', '-r', 'requirements.txt', '-q'], check=True)

# 4. NumPy ABI check — Colab pre-loads numpy 2.x in memory, but
#    requirements.txt pins 1.26.4. After pip downgrades, stale C-extensions
#    cause binary incompatibility. Fix: restart runtime ONCE.
_installed_np = subprocess.check_output(
    [sys.executable, '-c', 'import numpy; print(numpy.__version__)'],
    text=True,
).strip()

try:
    import numpy as _np
    _loaded_np = _np.__version__
except Exception:
    _loaded_np = None

if _loaded_np != _installed_np:
    print(f'\nNumPy ABI mismatch: loaded={_loaded_np}, installed={_installed_np}')
    print('Restarting runtime... After restart, re-run this cell (no second restart).')
    os.kill(os.getpid(), 9)
else:
    print(f'\nPlatform: {_platform}')
    print(f'Repo: {REPO_DIR}')
    print(f'Environment OK (numpy=={_installed_np})')

In [None]:
import sys, os, yaml, json, logging
from pathlib import Path
from collections import Counter

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Platform-aware persistent cache setup
from src.utils.platform import detect_platform, setup_environment

platform = detect_platform()
setup_environment(platform)

# Mount Google Drive only on Colab (Lightning.ai has persistent disk built-in)
if platform.needs_drive_mount:
    from google.colab import drive
    drive.mount('/content/drive')

from src.utils.seed import set_global_seed
from src.data.manifest import (
    ManifestEntry, read_manifest, write_manifest,
    normalize_cv_accent, compute_file_hash,
)
from src.data.manifest_builder import build_manifest_from_hf_dataset
from src.data.cv_manifest_builder import build_manifest_from_common_voice
from src.data.combined_manifest import combine_manifests, analyze_source_distribution
from src.data.splits import (
    generate_speaker_disjoint_splits,
    save_splits,
    assign_entries_to_splits,
)
from src.analysis.confounds import run_all_confound_checks
from src.classifier import (
    AccentCNN, AccentWav2Vec2,
    train_classifier, evaluate_classifier,
    TrainingConfig, TrainingResult,
)
from src.classifier.mel_dataset import MelSpectrogramDataset
from src.classifier.wav2vec2_dataset import WaveformDataset
from src.classifier.trainer import compute_class_weights

# Load config — single source of truth for all experiment parameters.
# Explicit UTF-8: the config may contain pt-BR text, and the platform default
# encoding is not guaranteed to be UTF-8 (e.g. cp1252 on Windows).
with open('configs/accent_classifier.yaml', encoding='utf-8') as f:
    config = yaml.safe_load(f)

SEED = config['seed']['global']
generator = set_global_seed(SEED)

logging.basicConfig(
    level=logging.INFO,
    format='%(name)s - %(levelname)s - %(message)s',
)

print(f'Platform: {platform.name}')
print(f'Config loaded: {config["experiment"]["name"]}')
print(f'Seed global: {SEED}')

In [None]:
# Environment check: record interpreter, PyTorch and (if present) GPU details
# so every run log states the exact software/hardware it ran on.
_cuda_ok = torch.cuda.is_available()

print(f'Python: {sys.version}')
print(f'PyTorch: {torch.__version__}')
print(f'CUDA available: {_cuda_ok}')
if _cuda_ok:
    _gpu_name = torch.cuda.get_device_name(0)
    _gpu_props = torch.cuda.get_device_properties(0)
    print(f'CUDA device: {_gpu_name}')
    print(f'CUDA version: {torch.version.cuda}')
    print(f'VRAM total: {_gpu_props.total_memory / 1e9:.1f} GB')

DEVICE = 'cuda' if _cuda_ok else 'cpu'
print(f'\nUsando device: {DEVICE}')

# Persistent cache root; its location depends on the detected platform.
DRIVE_BASE = platform.cache_base
DRIVE_BASE.mkdir(parents=True, exist_ok=True)
print(f'Cache base: {DRIVE_BASE}')

## CORAA-MUPE Manifest

Download CORAA-MUPE-ASR from HuggingFace, apply filters, and build the manifest JSONL.  

**Filters:**  
- `speaker_type='R'` (interviewees only)  
- Duration: 3–15s  
- `birth_state` válido → macrorregião IBGE (N, NE, CO, SE, S)  
- `min_speakers_per_region`: 8  
- `min_utterances_per_speaker`: 3  

**Nota:** O download inicial é de ~42 GB. Execuções subsequentes usam o cache persistente (Google Drive no Colab; disco persistente no Lightning.ai).

In [None]:
from datasets import load_dataset, concatenate_datasets

# CORAA-MUPE artifacts live under the persistent cache root, so only the first
# run pays for the download and the manifest build.
CORAA_AUDIO_DIR = DRIVE_BASE / 'coraa_mupe' / 'audio'
CORAA_MANIFEST_PATH = DRIVE_BASE / 'coraa_mupe' / 'manifest.jsonl'

if not CORAA_MANIFEST_PATH.exists():
    print('Downloading CORAA-MUPE-ASR from HuggingFace...')
    print('(~42 GB na primeira vez)')

    ds = load_dataset('nilc-nlp/CORAA-MUPE-ASR')
    print(f'Splits disponiveis: {list(ds.keys())}')
    for _split_label, _split_rows in ds.items():
        print(f'  {_split_label}: {len(_split_rows):,} rows')

    # Merge every official split: speaker-disjoint splits are generated later.
    all_data = concatenate_datasets(list(ds.values()))
    print(f'\nTotal concatenado: {len(all_data):,} rows')

    _coraa_filters = config['dataset']['filters']
    coraa_entries, coraa_stats = build_manifest_from_hf_dataset(
        dataset=all_data,
        audio_output_dir=CORAA_AUDIO_DIR,
        manifest_output_path=CORAA_MANIFEST_PATH,
        speaker_type_filter=_coraa_filters.get('speaker_type', 'R'),
        min_duration_s=_coraa_filters['min_duration_s'],
        max_duration_s=_coraa_filters['max_duration_s'],
        min_speakers_per_region=_coraa_filters['min_speakers_per_region'],
        min_utterances_per_speaker=_coraa_filters.get('min_utterances_per_speaker', 3),
    )
    coraa_sha256 = coraa_stats['manifest_sha256']

    print(f'\nCORAA-MUPE manifest: {len(coraa_entries):,} entries')
    print(f'SHA-256: {coraa_sha256}')
    for _region_code, _region_info in coraa_stats.get('regions', {}).items():
        print(f'  {_region_code}: {_region_info["n_speakers"]} speakers, {_region_info["n_utterances"]:,} utts')
else:
    print(f'Loading CORAA-MUPE manifest from cache: {CORAA_MANIFEST_PATH}')
    coraa_entries = read_manifest(CORAA_MANIFEST_PATH)
    coraa_sha256 = compute_file_hash(CORAA_MANIFEST_PATH)
    print(f'Loaded {len(coraa_entries):,} entries (SHA-256: {coraa_sha256[:16]}...)')

# Summary: utterances per accent region.
region_counts = Counter(e.accent for e in coraa_entries)
print(f'\nCORAA-MUPE: {len(coraa_entries):,} entries, regions: {dict(sorted(region_counts.items()))}')

## Common Voice PT Manifest

Load Common Voice Portuguese (v17.0) from HuggingFace and build the manifest.  
Speaker IDs are prefixed with `cv_` to avoid collisions with CORAA-MUPE.  
The `accent` field in Common Voice is user-submitted and noisy — `normalize_cv_accent()` handles the mapping.

In [None]:
# Paths for Common Voice manifest and audio
CV_AUDIO_DIR = DRIVE_BASE / 'common_voice_pt' / 'audio'
CV_MANIFEST_PATH = DRIVE_BASE / 'common_voice_pt' / 'manifest.jsonl'

if CV_MANIFEST_PATH.exists():
    print(f'Loading Common Voice PT manifest from cache: {CV_MANIFEST_PATH}')
    cv_entries = read_manifest(CV_MANIFEST_PATH)
    cv_sha256 = compute_file_hash(CV_MANIFEST_PATH)
    print(f'Loaded {len(cv_entries):,} entries (SHA-256: {cv_sha256[:16]}...)')
else:
    print('Loading Common Voice PT from HuggingFace...')

    # NOTE(review): assumes Common Voice is entry [1] of dataset.sources in the
    # config — confirm whenever the order of sources in the config changes.
    cv_hf_id = config['dataset']['sources'][1]['hf_id']
    cv_lang = config['dataset']['sources'][1]['hf_lang']
    # Report the IDs actually read from the config rather than a hard-coded
    # literal that can silently drift from what is loaded.
    print(f'({cv_hf_id}, lang={cv_lang})')

    # Load the validated split (most reliable labels)
    cv_dataset = load_dataset(cv_hf_id, cv_lang, split='validated')
    print(f'Common Voice validated split: {len(cv_dataset):,} rows')
    print(f'Columns: {cv_dataset.column_names}')

    cv_entries, cv_stats = build_manifest_from_common_voice(
        dataset=cv_dataset,
        audio_output_dir=CV_AUDIO_DIR,
        manifest_output_path=CV_MANIFEST_PATH,
        min_duration_s=config['dataset']['filters']['min_duration_s'],
        max_duration_s=config['dataset']['filters']['max_duration_s'],
        min_speakers_per_region=config['dataset']['filters']['min_speakers_per_region'],
        min_utterances_per_speaker=config['dataset']['filters'].get('min_utterances_per_speaker', 3),
    )
    cv_sha256 = cv_stats['manifest_sha256']

    print(f'\nCommon Voice PT manifest: {len(cv_entries):,} entries')
    print(f'SHA-256: {cv_sha256}')
    for region, info in cv_stats.get('regions', {}).items():
        print(f'  {region}: {info["n_speakers"]} speakers, {info["n_utterances"]:,} utts')

# Summary
region_counts_cv = Counter(e.accent for e in cv_entries)
print(f'\nCommon Voice PT: {len(cv_entries):,} entries, regions: {dict(sorted(region_counts_cv.items()))}')

## Combined Accents-PT-BR Dataset

Merge CORAA-MUPE and Common Voice manifests into a single dataset.  
Validates: no utt_id or speaker_id collisions across sources, speaker-accent consistency.

In [None]:
COMBINED_MANIFEST_PATH = DRIVE_BASE / 'accents_pt_br' / 'manifest.jsonl'

if not COMBINED_MANIFEST_PATH.exists():
    # Merge both per-source manifests into the Accents-PT-BR manifest.
    _combine_filters = config['dataset']['filters']
    combined_entries, combined_stats = combine_manifests(
        manifests=[
            (CORAA_MANIFEST_PATH, 'CORAA-MUPE'),
            (CV_MANIFEST_PATH, 'CommonVoice-PT'),
        ],
        output_path=COMBINED_MANIFEST_PATH,
        min_speakers_per_region=_combine_filters['min_speakers_per_region'],
        min_utterances_per_speaker=_combine_filters.get('min_utterances_per_speaker', 3),
    )
    combined_sha256 = combined_stats['manifest_sha256']

    print(f'Combined manifest: {len(combined_entries):,} entries')
    print(f'SHA-256: {combined_sha256}')
    print(f'\nPer-source (input): {combined_stats["per_source_input"]}')
    print(f'Per-source (output): {combined_stats["per_source_output"]}')
    for _region_code, _region_info in combined_stats.get('regions', {}).items():
        print(f'  {_region_code}: {_region_info["n_speakers"]} speakers, {_region_info["n_utterances"]:,} utts')
else:
    print(f'Loading combined manifest from cache: {COMBINED_MANIFEST_PATH}')
    combined_entries = read_manifest(COMBINED_MANIFEST_PATH)
    combined_sha256 = compute_file_hash(COMBINED_MANIFEST_PATH)
    print(f'Loaded {len(combined_entries):,} entries (SHA-256: {combined_sha256[:16]}...)')

# Source distribution analysis: how each source contributes to accent/gender.
source_dist = analyze_source_distribution(combined_entries)

print('\n=== SOURCE DISTRIBUTION ===')
for _heading, _table_key in (
    ('Source x Accent:', 'source_x_accent'),
    ('\nSource x Gender:', 'source_x_gender'),
):
    print(_heading)
    for source_name, counts in source_dist[_table_key].items():
        print(f'  {source_name}: {dict(sorted(counts.items()))}')

_source_warnings = source_dist['warnings']
if _source_warnings:
    print('\nWARNINGS:')
    for _warning in _source_warnings:
        print(f'  {_warning}')

# Overall summary
total_speakers = len(set(e.speaker_id for e in combined_entries))
region_counts_all = Counter(e.accent for e in combined_entries)
print(f'\nTotal: {len(combined_entries):,} entries, {total_speakers} speakers')
print(f'Regions: {dict(sorted(region_counts_all.items()))}')

## Confound Analysis (including accent x source)

**Sanity checks obrigatórios** (protocolo):  
- Tabela accent x gender com qui-quadrado + V de Cramér  
- Histograma de duração por região + Kruskal-Wallis  
- Accent x source: se uma região vem 80%+ de uma fonte, o classificador pode aprender a fonte em vez do sotaque  

SNR check desabilitado para velocidade (`check_snr=False`).

In [None]:
confound_results = run_all_confound_checks(
    combined_entries,
    gender_blocking_threshold=config['confounds']['accent_x_gender']['threshold_blocker'],
    duration_practical_diff_s=config['confounds']['accent_x_duration']['practical_diff_s'],
    check_snr=False,  # Skip SNR for speed
    source_blocking_threshold=config['confounds']['accent_x_source']['threshold_blocker'],
)

print('=== CONFOUND ANALYSIS ===')
blocking_found = False
for result in confound_results:
    if result.is_blocking:
        status = 'BLOCKING'
        blocking_found = True
    elif result.is_significant:
        status = 'SIGNIFICANT'
    else:
        status = 'OK'

    print(f'\n{result.variable_a} x {result.variable_b}: {status}')
    print(f'  Test: {result.test_name}')
    print(f'  Statistic: {result.statistic:.4f}')
    print(f'  p-value: {result.p_value:.6f}')
    print(f'  Effect size ({result.effect_size_name}): {result.effect_size:.4f}')
    print(f'  Interpretation: {result.interpretation}')

# pd.crosstab interprets a plain Python *list* as a list of grouping arrays
# (one per element), so passing list comprehensions directly would make every
# accent string its own "array" and the table construction fails. Build one
# frame and pass named Series instead (also labels the table axes).
_confound_df = pd.DataFrame({
    'accent': [e.accent for e in combined_entries],
    'gender': [e.gender for e in combined_entries],
    'source': [e.source for e in combined_entries],
})

# Cross-tabulation: accent x gender
gender_table = pd.crosstab(
    _confound_df['accent'],
    _confound_df['gender'],
    margins=True,
)
print('\n=== ACCENT x GENDER TABLE ===')
print(gender_table)

# Cross-tabulation: accent x source
source_table = pd.crosstab(
    _confound_df['accent'],
    _confound_df['source'],
    margins=True,
)
print('\n=== ACCENT x SOURCE TABLE ===')
print(source_table)

if blocking_found:
    print('\n*** BLOCKING CONFOUND DETECTED. Review before proceeding. ***')
else:
    print('\nNo blocking confounds detected. Proceeding.')

## Speaker-Disjoint Splits

**Obrigatório:** nenhum speaker aparece em mais de um split (train/val/test).  
Os splits são estratificados por sotaque para garantir representação em todos os splits.  
Os splits são artefatos versionados — não mudam dentro de um experimento.

In [None]:
split_info = generate_speaker_disjoint_splits(
    combined_entries,
    train_ratio=config['splits']['ratios']['train'],
    val_ratio=config['splits']['ratios']['val'],
    test_ratio=config['splits']['ratios']['test'],
    seed=config['splits']['seed'],
)

# Persist splits
split_output_dir = Path(config['splits']['output_dir'])
split_path = save_splits(split_info, split_output_dir)
print(f'Splits saved to: {split_path}')
print(f'Train: {len(split_info.train_speakers)} speakers, {split_info.utterances_per_split["train"]:,} utts')
print(f'Val:   {len(split_info.val_speakers)} speakers, {split_info.utterances_per_split["val"]:,} utts')
print(f'Test:  {len(split_info.test_speakers)} speakers, {split_info.utterances_per_split["test"]:,} utts')

# Assign entries to splits
split_entries = assign_entries_to_splits(combined_entries, split_info)

train_entries = split_entries['train']
val_entries = split_entries['val']
test_entries = split_entries['test']

# Verify speaker-disjoint (hard fail if violated). Explicit raises instead of
# `assert`, so the check cannot be stripped by `python -O`; the error also
# reports how many speakers leaked and a few of their IDs.
train_spk = {e.speaker_id for e in train_entries}
val_spk = {e.speaker_id for e in val_entries}
test_spk = {e.speaker_id for e in test_entries}

for _pair_name, _shared_spk in (
    ('train -> val', train_spk & val_spk),
    ('train -> test', train_spk & test_spk),
    ('val -> test', val_spk & test_spk),
):
    if _shared_spk:
        raise RuntimeError(
            f'Speaker leakage {_pair_name}: {len(_shared_spk)} shared speaker(s), '
            f'e.g. {sorted(map(str, _shared_spk))[:5]}'
        )
print('\nSpeaker-disjoint verification: PASSED')

# Distribution per split
for split_name, entries_list in split_entries.items():
    accent_dist = Counter(e.accent for e in entries_list)
    print(f'  {split_name}: {dict(sorted(accent_dist.items()))}')

## CNN Accent Classifier

3-block CNN operating on mel-spectrograms.  
Architecture: Conv2d -> BatchNorm -> ReLU -> MaxPool (x3) -> AdaptiveAvgPool -> Linear.  
Early stopping on validation balanced accuracy.  
Class-weighted CrossEntropyLoss for imbalanced accent distributions.

In [None]:
# Build label mapping (sorted, deterministic) and its inverse.
label_to_idx = MelSpectrogramDataset.build_label_mapping(combined_entries)
idx_to_label = {idx: label for label, idx in label_to_idx.items()}
n_classes = len(label_to_idx)
label_names = [idx_to_label[idx] for idx in range(n_classes)]

print(f'Classes ({n_classes}): {label_names}')
print(f'Label mapping: {label_to_idx}')

# Persist the mapping for reproducibility — same class indices across runs.
label_map_path = Path(config['output']['report_dir']) / 'label_to_idx.json'
label_map_path.parent.mkdir(parents=True, exist_ok=True)
with open(label_map_path, 'w') as _label_map_file:
    json.dump(label_to_idx, _label_map_file, indent=2)
print(f'Label mapping saved to: {label_map_path}')

# CNN hyperparameters from config
cnn_cfg = config['cnn']

# One mel-spectrogram dataset per split; all share the label mapping and
# the feature settings (n_mels, max_frames) from the config.
_mel_kwargs = dict(
    label_to_idx=label_to_idx,
    n_mels=cnn_cfg['n_mels'],
    max_frames=cnn_cfg['max_frames'],
)
train_mel_ds = MelSpectrogramDataset(entries=train_entries, **_mel_kwargs)
val_mel_ds = MelSpectrogramDataset(entries=val_entries, **_mel_kwargs)
test_mel_ds = MelSpectrogramDataset(entries=test_entries, **_mel_kwargs)

print(f'\nMel datasets: train={len(train_mel_ds)}, val={len(val_mel_ds)}, test={len(test_mel_ds)}')

# Create DataLoaders with reproducible worker seeds
def seed_worker(worker_id):
    """DataLoader ``worker_init_fn`` that seeds NumPy's global RNG.

    The seed is derived from ``torch.initial_seed()`` and reduced modulo 2**32,
    because ``np.random.seed`` only accepts 32-bit seeds.

    Args:
        worker_id: Worker index supplied by the DataLoader (unused).
    """
    np.random.seed(torch.initial_seed() % (1 << 32))

# Seeded generator: fixes the training shuffle order across runs.
g = torch.Generator()
g.manual_seed(SEED)

cnn_batch_size = cnn_cfg['training']['batch_size']

# Val/test loaders share every setting and are never shuffled.
_cnn_eval_loader_kwargs = dict(
    batch_size=cnn_batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

train_mel_loader = torch.utils.data.DataLoader(
    train_mel_ds,
    batch_size=cnn_batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    worker_init_fn=seed_worker,
    generator=g,
)
val_mel_loader = torch.utils.data.DataLoader(val_mel_ds, **_cnn_eval_loader_kwargs)
test_mel_loader = torch.utils.data.DataLoader(test_mel_ds, **_cnn_eval_loader_kwargs)

# Class weights derived from the training-split label distribution.
train_labels_cnn = [label_to_idx[e.accent] for e in train_entries]
cnn_class_weights = compute_class_weights(train_labels_cnn, n_classes)
print(f'CNN class weights: {cnn_class_weights.tolist()}')
print(f'CNN class weights: {cnn_class_weights.tolist()}')

In [None]:
# Train CNN
cnn_model = AccentCNN(
    n_classes=n_classes,
    n_mels=cnn_cfg['n_mels'],
    conv_channels=cnn_cfg['conv_channels'],
)

cnn_checkpoint_dir = Path(config['output']['checkpoint_dir']) / 'cnn'

cnn_training_config = TrainingConfig(
    learning_rate=cnn_cfg['training']['learning_rate'],
    batch_size=cnn_cfg['training']['batch_size'],
    n_epochs=cnn_cfg['training']['n_epochs'],
    patience=cnn_cfg['training']['patience'],
    device=DEVICE,
    seed=SEED,
    checkpoint_dir=cnn_checkpoint_dir,
    experiment_name='accent_cnn',
    use_amp=cnn_cfg['training']['use_amp'],
)

print(f'Training CNN: lr={cnn_training_config.learning_rate}, '
      f'epochs={cnn_training_config.n_epochs}, '
      f'patience={cnn_training_config.patience}')

cnn_result = train_classifier(
    model=cnn_model,
    train_loader=train_mel_loader,
    val_loader=val_mel_loader,
    config=cnn_training_config,
    class_weights=cnn_class_weights,
)

print(f'\nCNN training complete:')
print(f'  Best epoch: {cnn_result.best_epoch}')
print(f'  Best val bal_acc: {cnn_result.best_val_bal_acc:.4f}')
print(f'  Total epochs: {cnn_result.total_epochs_run}')
print(f'  Checkpoint: {cnn_result.best_checkpoint_path}')

# Plot training curves
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

ax1.plot(cnn_result.train_losses, label='Train Loss')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('CNN Training Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

ax2.plot(cnn_result.val_bal_accs, label='Val Balanced Accuracy', color='orange')
ax2.axhline(y=1.0/n_classes, color='red', linestyle='--', label=f'Chance ({1.0/n_classes:.2f})')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Balanced Accuracy')
ax2.set_title('CNN Validation Balanced Accuracy')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
figures_dir = Path(config['output']['figures_dir'])
figures_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(figures_dir / 'cnn_training_curves.png', dpi=150)
plt.show()

In [None]:
# Evaluate CNN on test set
# Load best checkpoint (weights_only=False: checkpoint contains config dict)
checkpoint = torch.load(cnn_result.best_checkpoint_path, map_location=DEVICE, weights_only=False)
cnn_model.load_state_dict(checkpoint['model_state_dict'])

cnn_eval = evaluate_classifier(
    model=cnn_model,
    test_loader=test_mel_loader,
    label_names=label_names,
    device=DEVICE,
    n_bootstrap=config['evaluation']['bootstrap_n_samples'],
)

print('=== CNN TEST EVALUATION ===')
print(f'Balanced Accuracy: {cnn_eval["balanced_accuracy"]:.4f} '
      f'(CI 95%: [{cnn_eval["ci_95_lower"]:.4f}, {cnn_eval["ci_95_upper"]:.4f}])')
print(f'F1 Macro: {cnn_eval["f1_macro"]:.4f}')
print(f'Chance level: {1.0/n_classes:.4f}')
print(f'\nPer-class recall:')
for name, recall in cnn_eval['per_class_recall'].items():
    print(f'  {name}: {recall:.4f}')

# Display confusion matrix
cm = np.array(cnn_eval['confusion_matrix'])
cm_norm = cm.astype(float) / cm.sum(axis=1, keepdims=True)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

sns.heatmap(cm, annot=True, fmt='d', xticklabels=label_names,
            yticklabels=label_names, cmap='Blues', ax=ax1)
ax1.set_title('CNN Confusion Matrix (counts)')
ax1.set_xlabel('Predicted')
ax1.set_ylabel('True')

sns.heatmap(cm_norm, annot=True, fmt='.2f', xticklabels=label_names,
            yticklabels=label_names, cmap='Blues', ax=ax2)
ax2.set_title('CNN Confusion Matrix (row-normalized recall)')
ax2.set_xlabel('Predicted')
ax2.set_ylabel('True')

plt.tight_layout()
plt.savefig(figures_dir / 'cnn_confusion_matrix.png', dpi=150)
plt.show()

## wav2vec2 Accent Classifier

Pre-trained wav2vec2-base with frozen CNN feature extractor + fine-tuned transformer + linear head.  
Operates on raw waveforms (no mel-spectrogram preprocessing).  
Smaller batch size due to VRAM constraints.

In [None]:
# wav2vec2 hyperparameters from config
w2v_cfg = config['wav2vec2']

# Raw-waveform datasets; every split shares the label mapping and the
# configured max_length_s.
_wav_kwargs = dict(
    label_to_idx=label_to_idx,
    max_length_s=w2v_cfg['max_length_s'],
)
train_wav_ds = WaveformDataset(entries=train_entries, **_wav_kwargs)
val_wav_ds = WaveformDataset(entries=val_entries, **_wav_kwargs)
test_wav_ds = WaveformDataset(entries=test_entries, **_wav_kwargs)

print(f'Waveform datasets: train={len(train_wav_ds)}, val={len(val_wav_ds)}, test={len(test_wav_ds)}')

# Separate seeded generator so the wav2vec2 shuffle order is reproducible
# independently of the CNN loaders.
g_w2v = torch.Generator()
g_w2v.manual_seed(SEED)

# Batch size comes from the wav2vec2 section (smaller, for VRAM).
w2v_batch_size = w2v_cfg['training']['batch_size']

_w2v_eval_loader_kwargs = dict(
    batch_size=w2v_batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

train_wav_loader = torch.utils.data.DataLoader(
    train_wav_ds,
    batch_size=w2v_batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    worker_init_fn=seed_worker,
    generator=g_w2v,
)
val_wav_loader = torch.utils.data.DataLoader(val_wav_ds, **_w2v_eval_loader_kwargs)
test_wav_loader = torch.utils.data.DataLoader(test_wav_ds, **_w2v_eval_loader_kwargs)

# Same training split as the CNN, hence the same label list for the weights.
w2v_class_weights = compute_class_weights(train_labels_cnn, n_classes)
print(f'wav2vec2 class weights: {w2v_class_weights.tolist()}')

In [None]:
# Train wav2vec2
# Free CNN memory before loading wav2vec2. The guard keeps the cell
# re-runnable: a bare `del cnn_model` raises NameError on the second run.
import gc

if 'cnn_model' in globals():
    del cnn_model
# Collect unreachable objects first so the CUDA caching allocator actually has
# unreferenced blocks to release when empty_cache() runs.
gc.collect()
torch.cuda.empty_cache()

w2v_model = AccentWav2Vec2(
    n_classes=n_classes,
    model_name=w2v_cfg['model_name'],
    freeze_feature_extractor=w2v_cfg['freeze_feature_extractor'],
)

w2v_checkpoint_dir = Path(config['output']['checkpoint_dir']) / 'wav2vec2'

w2v_training_config = TrainingConfig(
    learning_rate=w2v_cfg['training']['learning_rate'],
    batch_size=w2v_cfg['training']['batch_size'],
    n_epochs=w2v_cfg['training']['n_epochs'],
    patience=w2v_cfg['training']['patience'],
    device=DEVICE,
    seed=SEED,
    checkpoint_dir=w2v_checkpoint_dir,
    experiment_name='accent_wav2vec2',
    use_amp=w2v_cfg['training']['use_amp'],
)

print(f'Training wav2vec2: lr={w2v_training_config.learning_rate}, '
      f'epochs={w2v_training_config.n_epochs}, '
      f'patience={w2v_training_config.patience}')
print(f'VRAM before training: {torch.cuda.memory_allocated()/1e9:.2f} GB')

w2v_result = train_classifier(
    model=w2v_model,
    train_loader=train_wav_loader,
    val_loader=val_wav_loader,
    config=w2v_training_config,
    class_weights=w2v_class_weights,
)

print(f'\nwav2vec2 training complete:')
print(f'  Best epoch: {w2v_result.best_epoch}')
print(f'  Best val bal_acc: {w2v_result.best_val_bal_acc:.4f}')
print(f'  Total epochs: {w2v_result.total_epochs_run}')
print(f'  Checkpoint: {w2v_result.best_checkpoint_path}')

# Plot training curves
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

ax1.plot(w2v_result.train_losses, label='Train Loss')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('wav2vec2 Training Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

ax2.plot(w2v_result.val_bal_accs, label='Val Balanced Accuracy', color='orange')
ax2.axhline(y=1.0/n_classes, color='red', linestyle='--', label=f'Chance ({1.0/n_classes:.2f})')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Balanced Accuracy')
ax2.set_title('wav2vec2 Validation Balanced Accuracy')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(figures_dir / 'wav2vec2_training_curves.png', dpi=150)
plt.show()

In [None]:
# Evaluate wav2vec2 on test set
# Load the best checkpoint on CPU: load_state_dict copies the tensors into the
# model's own parameters, so a DEVICE-mapped checkpoint dict would keep a
# redundant copy of the wav2vec2 weights in VRAM while this global is alive.
checkpoint_w2v = torch.load(w2v_result.best_checkpoint_path, map_location='cpu', weights_only=False)
w2v_model.load_state_dict(checkpoint_w2v['model_state_dict'])

w2v_eval = evaluate_classifier(
    model=w2v_model,
    test_loader=test_wav_loader,
    label_names=label_names,
    device=DEVICE,
    n_bootstrap=config['evaluation']['bootstrap_n_samples'],
)

print('=== WAV2VEC2 TEST EVALUATION ===')
print(f'Balanced Accuracy: {w2v_eval["balanced_accuracy"]:.4f} '
      f'(CI 95%: [{w2v_eval["ci_95_lower"]:.4f}, {w2v_eval["ci_95_upper"]:.4f}])')
print(f'F1 Macro: {w2v_eval["f1_macro"]:.4f}')
print(f'Chance level: {1.0/n_classes:.4f}')
print(f'\nPer-class recall:')
for name, recall in w2v_eval['per_class_recall'].items():
    print(f'  {name}: {recall:.4f}')

# Display confusion matrix
cm_w2v = np.array(w2v_eval['confusion_matrix'])
# Row-normalize to per-class recall. A class with no test samples has an
# undefined recall: leave it NaN (seaborn masks NaN cells) instead of dividing
# by zero and emitting a RuntimeWarning.
_cm_w2v_row_totals = cm_w2v.sum(axis=1, keepdims=True)
cm_w2v_norm = np.divide(
    cm_w2v.astype(float),
    _cm_w2v_row_totals,
    out=np.full(cm_w2v.shape, np.nan),
    where=_cm_w2v_row_totals > 0,
)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

sns.heatmap(cm_w2v, annot=True, fmt='d', xticklabels=label_names,
            yticklabels=label_names, cmap='Greens', ax=ax1)
ax1.set_title('wav2vec2 Confusion Matrix (counts)')
ax1.set_xlabel('Predicted')
ax1.set_ylabel('True')

sns.heatmap(cm_w2v_norm, annot=True, fmt='.2f', xticklabels=label_names,
            yticklabels=label_names, cmap='Greens', ax=ax2)
ax2.set_title('wav2vec2 Confusion Matrix (row-normalized recall)')
ax2.set_xlabel('Predicted')
ax2.set_ylabel('True')

plt.tight_layout()
plt.savefig(figures_dir / 'wav2vec2_confusion_matrix.png', dpi=150)
plt.show()

## Cross-Source Evaluation (Source Confound Check)

**Objetivo:** verificar se o classificador aprendeu *sotaque* ou *fonte*.  
- Treinamos na fonte A (CORAA-MUPE), testamos na fonte B (Common Voice) e vice-versa.  
- Se ambas as direções ficam no nível de chance -> o classificador aprendeu a fonte, não o sotaque.  
- Se ao menos uma direção fica acima do nível de chance -> há sinal de sotaque transferível entre fontes.  

Usamos a CNN (mais rápida) para este cross-source check.

In [None]:
# Check that primary classifier is above chance before cross-source
chance = 1.0 / n_classes
best_primary = max(cnn_eval['balanced_accuracy'], w2v_eval['balanced_accuracy'])
if best_primary <= chance + 0.05:
    print(f'WARNING: Best primary classifier ({best_primary:.4f}) is at chance ({chance:.4f}).')
    print('Cross-source evaluation may not be meaningful.')

# Split combined entries by source
coraa_train = [e for e in train_entries if e.source == 'CORAA-MUPE']
coraa_test_split = [e for e in test_entries if e.source == 'CORAA-MUPE']
cv_train = [e for e in train_entries if e.source == 'CommonVoice-PT']
cv_test_split = [e for e in test_entries if e.source == 'CommonVoice-PT']

print(f'CORAA-MUPE: train={len(coraa_train)}, test={len(coraa_test_split)}')
print(f'CommonVoice-PT: train={len(cv_train)}, test={len(cv_test_split)}')

cross_source_results = {}

# Direction 1: Train on CORAA-MUPE, test on CommonVoice-PT
if len(coraa_train) > 0 and len(cv_test_split) > 0:
    print('\n--- Direction 1: Train CORAA-MUPE -> Test CommonVoice-PT ---')

    cs_train_ds = MelSpectrogramDataset(
        coraa_train, label_to_idx, n_mels=cnn_cfg['n_mels'], max_frames=cnn_cfg['max_frames'],
    )
    cs_test_ds = MelSpectrogramDataset(
        cv_test_split, label_to_idx, n_mels=cnn_cfg['n_mels'], max_frames=cnn_cfg['max_frames'],
    )

    # Validation: prefer the same-source val split (never train data)
    coraa_val = [e for e in val_entries if e.source == 'CORAA-MUPE']
    if len(coraa_val) < 10:
        # Not enough same-source val samples — fall back to the full val split.
        # It is speaker-disjoint from train, but it also contains CommonVoice-PT
        # (target-source) utterances, so early stopping then sees the target
        # domain and the CORAA->CV score should be read as optimistic.
        coraa_val = list(val_entries)
        print(f'  Val fallback: using full val split ({len(coraa_val)} entries, mixed source)')

    cs_val_ds = MelSpectrogramDataset(
        coraa_val, label_to_idx, n_mels=cnn_cfg['n_mels'], max_frames=cnn_cfg['max_frames'],
    )

    g_cs1 = torch.Generator()
    g_cs1.manual_seed(SEED)

    cs_train_loader = torch.utils.data.DataLoader(
        cs_train_ds, batch_size=cnn_batch_size, shuffle=True,
        num_workers=2, worker_init_fn=seed_worker, generator=g_cs1, pin_memory=True,
    )
    cs_val_loader = torch.utils.data.DataLoader(
        cs_val_ds, batch_size=cnn_batch_size, shuffle=False, num_workers=2, pin_memory=True,
    )
    cs_test_loader = torch.utils.data.DataLoader(
        cs_test_ds, batch_size=cnn_batch_size, shuffle=False, num_workers=2, pin_memory=True,
    )

    # NOTE(review): if CORAA-MUPE lacks some accent class entirely, how
    # compute_class_weights handles the zero count is not visible here — confirm.
    cs_labels = [label_to_idx[e.accent] for e in coraa_train]
    cs_weights = compute_class_weights(cs_labels, n_classes)

    cs_model_1 = AccentCNN(n_classes=n_classes, n_mels=cnn_cfg['n_mels'], conv_channels=cnn_cfg['conv_channels'])
    cs_config_1 = TrainingConfig(
        learning_rate=cnn_cfg['training']['learning_rate'],
        batch_size=cnn_batch_size, n_epochs=cnn_cfg['training']['n_epochs'],
        patience=cnn_cfg['training']['patience'], device=DEVICE, seed=SEED,
        checkpoint_dir=Path(config['output']['checkpoint_dir']) / 'cross_source_coraa2cv',
        experiment_name='cross_source_coraa2cv', use_amp=cnn_cfg['training']['use_amp'],
    )

    cs_result_1 = train_classifier(
        model=cs_model_1,
        train_loader=cs_train_loader,
        val_loader=cs_val_loader,
        config=cs_config_1,
        class_weights=cs_weights,
    )

    # Load on CPU: load_state_dict copies the tensors into the model, so a
    # DEVICE-mapped checkpoint dict would only duplicate the weights in VRAM.
    checkpoint_cs1 = torch.load(cs_result_1.best_checkpoint_path, map_location='cpu', weights_only=False)
    cs_model_1.load_state_dict(checkpoint_cs1['model_state_dict'])

    # Same bootstrap size as the primary test evaluations so CIs are comparable.
    cs_eval_1 = evaluate_classifier(
        model=cs_model_1,
        test_loader=cs_test_loader,
        label_names=label_names,
        device=DEVICE,
        n_bootstrap=config['evaluation']['bootstrap_n_samples'],
    )
    cross_source_results['coraa2cv'] = cs_eval_1

    print(f'CORAA->CV bal_acc: {cs_eval_1["balanced_accuracy"]:.4f} '
          f'(CI: [{cs_eval_1["ci_95_lower"]:.4f}, {cs_eval_1["ci_95_upper"]:.4f}])')

    del cs_model_1
    torch.cuda.empty_cache()
else:
    print('Skipping direction 1: insufficient data')

# Direction 2: Train on CommonVoice-PT, test on CORAA-MUPE.
# Trains a fresh AccentCNN on CommonVoice-PT training entries only, selects the
# best checkpoint on a validation split, and reports balanced accuracy (+ CI)
# on the CORAA-MUPE test split. The result is stored in
# cross_source_results['cv2coraa'].
if len(cv_train) > 0 and len(coraa_test_split) > 0:
    print('\n--- Direction 2: Train CommonVoice-PT -> Test CORAA-MUPE ---')

    cs_train_ds2 = MelSpectrogramDataset(
        cv_train, label_to_idx, n_mels=cnn_cfg['n_mels'], max_frames=cnn_cfg['max_frames'],
    )
    cs_test_ds2 = MelSpectrogramDataset(
        coraa_test_split, label_to_idx, n_mels=cnn_cfg['n_mels'], max_frames=cnn_cfg['max_frames'],
    )

    # Validation: use same-source val split (never train data)
    cv_val = [e for e in val_entries if e.source == 'CommonVoice-PT']
    if len(cv_val) < 10:
        # NOTE(review): this fallback also pulls in CORAA-MUPE val entries, i.e.
        # the *target* source of this direction. Best-checkpoint selection then
        # sees target-source audio, so a cross-source score obtained through
        # this fallback should be read as optimistic.
        cv_val = list(val_entries)
        print(f'  Val fallback: using full val split ({len(cv_val)} entries, mixed source)')

    cs_val_ds2 = MelSpectrogramDataset(
        cv_val, label_to_idx, n_mels=cnn_cfg['n_mels'], max_frames=cnn_cfg['max_frames'],
    )

    # Dedicated, seeded generator so the training shuffle order is reproducible
    g_cs2 = torch.Generator()
    g_cs2.manual_seed(SEED)

    cs_train_loader2 = torch.utils.data.DataLoader(
        cs_train_ds2, batch_size=cnn_batch_size, shuffle=True,
        num_workers=2, worker_init_fn=seed_worker, generator=g_cs2, pin_memory=True,
    )
    cs_val_loader2 = torch.utils.data.DataLoader(
        cs_val_ds2, batch_size=cnn_batch_size, shuffle=False, num_workers=2, pin_memory=True,
    )
    cs_test_loader2 = torch.utils.data.DataLoader(
        cs_test_ds2, batch_size=cnn_batch_size, shuffle=False, num_workers=2, pin_memory=True,
    )

    # Class weights come from the training source only
    cs_labels2 = [label_to_idx[e.accent] for e in cv_train]
    cs_weights2 = compute_class_weights(cs_labels2, n_classes)

    cs_model_2 = AccentCNN(n_classes=n_classes, n_mels=cnn_cfg['n_mels'], conv_channels=cnn_cfg['conv_channels'])
    cs_config_2 = TrainingConfig(
        learning_rate=cnn_cfg['training']['learning_rate'],
        batch_size=cnn_batch_size, n_epochs=cnn_cfg['training']['n_epochs'],
        patience=cnn_cfg['training']['patience'], device=DEVICE, seed=SEED,
        checkpoint_dir=Path(config['output']['checkpoint_dir']) / 'cross_source_cv2coraa',
        experiment_name='cross_source_cv2coraa', use_amp=cnn_cfg['training']['use_amp'],
    )

    cs_result_2 = train_classifier(cs_model_2, cs_train_loader2, cs_val_loader2, cs_config_2, cs_weights2)

    # Restore the best (validation-selected) weights before testing.
    # weights_only=False: presumably the checkpoint stores non-tensor metadata
    # alongside 'model_state_dict' — TODO confirm before tightening this.
    checkpoint_cs2 = torch.load(cs_result_2.best_checkpoint_path, map_location=DEVICE, weights_only=False)
    cs_model_2.load_state_dict(checkpoint_cs2['model_state_dict'])

    cs_eval_2 = evaluate_classifier(cs_model_2, cs_test_loader2, label_names, DEVICE)
    cross_source_results['cv2coraa'] = cs_eval_2

    print(f'CV->CORAA bal_acc: {cs_eval_2["balanced_accuracy"]:.4f} '
          f'(CI: [{cs_eval_2["ci_95_lower"]:.4f}, {cs_eval_2["ci_95_upper"]:.4f}])')

    # Drop the model AND the loaded checkpoint: the checkpoint was mapped onto
    # DEVICE, and load_state_dict copied (not moved) its tensors, so keeping it
    # alive would pin a second copy of the weights in device memory.
    del cs_model_2, checkpoint_cs2
    torch.cuda.empty_cache()
else:
    print('Skipping direction 2: insufficient data')

# Interpretation: a direction "transfers" when its balanced accuracy clears the
# chance level by a fixed margin; otherwise the classifier may be keying on the
# data source rather than on accent.
TRANSFER_MARGIN = 0.05  # 5pp above chance

print('\n=== CROSS-SOURCE SUMMARY ===')
print(f'Chance level: {chance:.4f}')
for direction, eval_result in cross_source_results.items():
    ba = eval_result['balanced_accuracy']
    above_chance = ba > chance + TRANSFER_MARGIN
    status = 'ABOVE CHANCE (accent signal transfers)' if above_chance else 'AT CHANCE (possible source confound)'
    print(f'  {direction}: bal_acc={ba:.4f} -> {status}')

n_directions = len(cross_source_results)
if n_directions == 0:
    # all() over an empty iterable is True; without this guard the cell would
    # report "both directions at chance" when nothing was evaluated at all.
    print('\nWARNING: No cross-source direction was evaluated (insufficient data). '
          'Source confound could not be checked.')
elif all(r['balanced_accuracy'] <= chance + TRANSFER_MARGIN for r in cross_source_results.values()):
    if n_directions == 1:
        print('\nWARNING: The only evaluated direction is at chance. '
              'Classifier may have learned source, not accent.')
    else:
        print('\nWARNING: Both directions at chance. Classifier may have learned source, not accent.')
else:
    print('\nAt least one direction shows transfer. Accent signal appears generalizable across sources.')

## Ablation Summary

Comparison table for CNN vs wav2vec2 (balanced accuracy, 95% CI, macro F1, best/total epochs), followed by the cross-source results.  
All metrics follow the evaluation protocol: balanced accuracy is the primary metric, reported with a 95% bootstrap confidence interval (1000 resamples).

In [None]:
# Ablation table: one row per classifier. Metric columns are formatted to four
# decimals; epoch columns stay as raw integers.
chance = 1.0 / n_classes

# (display name, evaluation metrics dict, training result) per ablation arm
_ablation_rows = (
    ('CNN (mel-spectrogram)', cnn_eval, cnn_result),
    ('wav2vec2-base', w2v_eval, w2v_result),
)

comparison_data = {
    'Model': [label for label, _, _ in _ablation_rows],
    'Balanced Accuracy': [f'{ev["balanced_accuracy"]:.4f}' for _, ev, _ in _ablation_rows],
    'CI 95% Lower': [f'{ev["ci_95_lower"]:.4f}' for _, ev, _ in _ablation_rows],
    'CI 95% Upper': [f'{ev["ci_95_upper"]:.4f}' for _, ev, _ in _ablation_rows],
    'F1 Macro': [f'{ev["f1_macro"]:.4f}' for _, ev, _ in _ablation_rows],
    'Best Epoch': [res.best_epoch for _, _, res in _ablation_rows],
    'Total Epochs': [res.total_epochs_run for _, _, res in _ablation_rows],
}

comparison_df = pd.DataFrame(comparison_data)
print('=== ABLATION: CNN vs wav2vec2 ===')
print(f'Chance level: {chance:.4f}')
print('\n' + comparison_df.to_string(index=False))

# A winner is only declared when the two 95% CIs are disjoint.
cnn_ci = (cnn_eval['ci_95_lower'], cnn_eval['ci_95_upper'])
w2v_ci = (w2v_eval['ci_95_lower'], w2v_eval['ci_95_upper'])

overlap = cnn_ci[0] <= w2v_ci[1] and w2v_ci[0] <= cnn_ci[1]
if not overlap:
    if w2v_eval['balanced_accuracy'] > cnn_eval['balanced_accuracy']:
        winner = 'wav2vec2'
    else:
        winner = 'CNN'
    print(f'\nCIs do NOT overlap -> {winner} is significantly better.')
else:
    print('\nCIs overlap -> cannot claim one model is superior.')

# Cross-source results (only the directions that actually ran)
if cross_source_results:
    print('\n=== CROSS-SOURCE EVALUATION ===')
    for direction, result in cross_source_results.items():
        ba_str = f'{result["balanced_accuracy"]:.4f}'
        ci_str = f'[{result["ci_95_lower"]:.4f}, {result["ci_95_upper"]:.4f}]'
        print(f'  {direction}: bal_acc={ba_str} (CI: {ci_str})')

In [None]:
# Save full report as JSON with all metrics, configs, and hashes
import subprocess
from datetime import datetime


def _json_default(obj):
    """Fallback serializer for values ``json`` cannot encode natively.

    Objects exposing a callable ``tolist()`` (numpy arrays/scalars, torch
    tensors) are converted to plain Python lists/scalars so metric arrays are
    stored in full instead of as a truncated ``str()`` rendering. Anything else
    (paths, dataclasses, ...) falls back to ``str(obj)``.
    """
    tolist = getattr(obj, 'tolist', None)
    if callable(tolist):
        return tolist()
    return str(obj)


# Git commit hash (stderr silenced so a non-git checkout does not print "fatal: ...")
try:
    commit_hash = subprocess.check_output(
        ['git', 'rev-parse', 'HEAD'], text=True, stderr=subprocess.DEVNULL,
    ).strip()
except (subprocess.CalledProcessError, OSError):
    # Not inside a git work tree, or git is not installed
    commit_hash = 'unknown'

# SHA-256 of the combined manifest
manifest_sha256 = compute_file_hash(COMBINED_MANIFEST_PATH) if COMBINED_MANIFEST_PATH.exists() else 'N/A'

report = {
    'experiment': config['experiment']['name'],
    # Timezone-aware timestamp (local offset included) for unambiguous provenance
    'date': datetime.now().astimezone().isoformat(),
    'commit_hash': commit_hash,
    'seed': SEED,
    'environment': {
        'python_version': sys.version,
        'torch_version': torch.__version__,
        'cuda_version': torch.version.cuda if torch.cuda.is_available() else None,
        'gpu_name': torch.cuda.get_device_name(0) if torch.cuda.is_available() else None,
    },
    'dataset': {
        'name': config['dataset']['name'],
        'combined_manifest_sha256': manifest_sha256,
        'total_entries': len(combined_entries),
        'total_speakers': len({e.speaker_id for e in combined_entries}),
        'n_classes': n_classes,
        'label_names': label_names,
        'per_source': dict(Counter(e.source for e in combined_entries)),
        'per_accent': dict(Counter(e.accent for e in combined_entries)),
    },
    'splits': {
        'method': config['splits']['method'],
        'ratios': config['splits']['ratios'],
        'seed': config['splits']['seed'],
        'train_utterances': len(train_entries),
        'val_utterances': len(val_entries),
        'test_utterances': len(test_entries),
        'train_speakers': len(split_info.train_speakers),
        'val_speakers': len(split_info.val_speakers),
        'test_speakers': len(split_info.test_speakers),
    },
    'confounds': [
        {
            'test': r.test_name,
            'variables': f'{r.variable_a} x {r.variable_b}',
            'statistic': r.statistic,
            'p_value': r.p_value,
            'effect_size': r.effect_size,
            'effect_size_name': r.effect_size_name,
            'is_blocking': r.is_blocking,
            'is_significant': r.is_significant,
            'interpretation': r.interpretation,
        }
        for r in confound_results
    ],
    'cnn': {
        'config': {
            'n_mels': cnn_cfg['n_mels'],
            'max_frames': cnn_cfg['max_frames'],
            'conv_channels': cnn_cfg['conv_channels'],
            'learning_rate': cnn_cfg['training']['learning_rate'],
            'batch_size': cnn_cfg['training']['batch_size'],
            'n_epochs': cnn_cfg['training']['n_epochs'],
            'patience': cnn_cfg['training']['patience'],
        },
        'training': {
            'best_epoch': cnn_result.best_epoch,
            'best_val_bal_acc': cnn_result.best_val_bal_acc,
            'total_epochs_run': cnn_result.total_epochs_run,
            'checkpoint_path': str(cnn_result.best_checkpoint_path),
        },
        'evaluation': cnn_eval,
    },
    'wav2vec2': {
        'config': {
            'model_name': w2v_cfg['model_name'],
            'freeze_feature_extractor': w2v_cfg['freeze_feature_extractor'],
            'max_length_s': w2v_cfg['max_length_s'],
            'learning_rate': w2v_cfg['training']['learning_rate'],
            'batch_size': w2v_cfg['training']['batch_size'],
            'n_epochs': w2v_cfg['training']['n_epochs'],
            'patience': w2v_cfg['training']['patience'],
        },
        'training': {
            'best_epoch': w2v_result.best_epoch,
            'best_val_bal_acc': w2v_result.best_val_bal_acc,
            'total_epochs_run': w2v_result.total_epochs_run,
            'checkpoint_path': str(w2v_result.best_checkpoint_path),
        },
        'evaluation': w2v_eval,
    },
    'cross_source': cross_source_results,
    'chance_level': chance,
}

# Save report
report_dir = Path(config['output']['report_dir'])
report_dir.mkdir(parents=True, exist_ok=True)
report_path = Path(config['output']['report_json'])
# report_json is configured independently of report_dir, so its own parent
# directory must exist as well before opening the file for writing.
report_path.parent.mkdir(parents=True, exist_ok=True)

with open(report_path, 'w', encoding='utf-8') as f:
    json.dump(report, f, indent=2, default=_json_default)

print(f'Report saved to: {report_path}')
print(f'Combined manifest SHA-256: {manifest_sha256}')
print(f'Commit hash: {commit_hash}')
print('\n=== EXPERIMENT COMPLETE ===')
print(f'CNN bal_acc: {cnn_eval["balanced_accuracy"]:.4f} '
      f'(CI: [{cnn_eval["ci_95_lower"]:.4f}, {cnn_eval["ci_95_upper"]:.4f}])')
print(f'wav2vec2 bal_acc: {w2v_eval["balanced_accuracy"]:.4f} '
      f'(CI: [{w2v_eval["ci_95_lower"]:.4f}, {w2v_eval["ci_95_upper"]:.4f}])')
print(f'Chance level: {chance:.4f}')