# RespiraHub — Trial 6: Dual-Backbone Ensemble (Wav2Vec2 + HeAR)

**Insight:** Wav2Vec2 (speech) and HeAR (health acoustics) reach a similar AUROC (~0.72) but likely capture different features, so combining them should provide complementary signal.

**Approach:**
1. Extract Wav2Vec2 embeddings (frozen, 768-dim) from 2s segments
2. Extract HeAR embeddings (frozen, 1024-dim) from 2s segments — already available
3. Concatenate → 1792-dim per segment
4. Train an MLP classifier on top of the combined representation

**Why this might work:**
- Wav2Vec2: good at temporal speech patterns, learned from 960h Librispeech
- HeAR: good at health acoustic events, learned from 313M health clips
- Different pre-training = different feature spaces = complementary

**Target:** 0.75+ (beat DREAM Challenge winner 0.743)

---
## Cell 1: Setup

In [None]:
import os, json, random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchaudio
import matplotlib.pyplot as plt
from collections import Counter
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score, roc_curve
from torch.utils.data import Dataset, DataLoader
from transformers import Wav2Vec2Model, AutoModel
from tqdm import tqdm
import importlib
import warnings
warnings.filterwarnings('ignore')

# HeAR preprocessing
# The helper module is resolved by dotted path at runtime via importlib.
# NOTE(review): presumably this needs the HeAR repository importable as the
# `hear` package (e.g. on sys.path) — confirm the expected setup.
audio_utils = importlib.import_module('hear.python.data_processing.audio_utils')
# Applied to raw waveform batches before they are passed to the HeAR model.
preprocess_audio = audio_utils.preprocess_audio

# Seed every RNG used in this notebook (Python, NumPy, PyTorch).
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Device preference order: Apple MPS, then CUDA, then CPU.
DEVICE = torch.device(
    'mps' if torch.backends.mps.is_available()
    else 'cuda' if torch.cuda.is_available()
    else 'cpu'
)

print(f'PyTorch {torch.__version__}, Device: {DEVICE}')

---
## Cell 2: Load Metadata + Build 2s Segments

In [None]:
AUDIO_DIR = '/Users/aida/code/development/tb-datasets/data/solicited/'
CLINICAL_PATH = '/Users/aida/code/development/tb-datasets/data/metadata/CODA_TB_Clinical_Meta_Info.csv'
ADDITIONAL_PATH = '/Users/aida/code/development/tb-datasets/data/metadata/CODA_TB_additional_variables_train.csv'
SOLICITED_PATH = '/Users/aida/code/development/tb-datasets/data/metadata/CODA_TB_Solicited_Meta_Info.csv'

# Read the three metadata tables in the same order as the path constants.
clinical, additional, solicited = (
    pd.read_csv(csv_path)
    for csv_path in (CLINICAL_PATH, ADDITIONAL_PATH, SOLICITED_PATH)
)

# One row per participant; the binary label is derived from `tb_status`.
meta = pd.merge(clinical, additional, on='participant', how='left')
meta['label'] = meta['tb_status'].astype(int)

# One row per recording, keeping only recordings whose file exists on disk.
# NOTE(review): this is a left join, so a recording whose participant is
# missing from the clinical table would carry a NaN label — confirm none exist.
df = pd.merge(solicited, meta, on='participant', how='left')
df['filepath'] = df['filename'].map(lambda name: os.path.join(AUDIO_DIR, name))
df['file_exists'] = df['filepath'].map(os.path.exists)
df = df.loc[df['file_exists']].reset_index(drop=True)

print(f'Participants: {meta["label"].count()}, TB+: {meta["label"].sum()}')
print(f'Cough files: {len(df)} from {df["participant"].nunique()} participants')

# === Segmentation: concatenate all + split 2s (same as Trial 5) ===
# Sample rate every recording is resampled to when loaded.
TARGET_SR = 16000
# Length of one classifier input segment, in seconds.
SEGMENT_SEC = 2.0
SEGMENT_SAMPLES = int(TARGET_SR * SEGMENT_SEC)  # 32000
# Number of zero samples (0.05 s) inserted between consecutive recordings
# when a participant's files are concatenated.
GAP_SAMPLES = int(TARGET_SR * 0.05)

def load_audio(filepath, target_sr=16000):
    """Load an audio file as a mono 1-D waveform at ``target_sr``.

    Multi-channel audio is averaged across channels, and the signal is
    resampled when the file's sample rate differs from ``target_sr``.

    Args:
        filepath: Path handed to ``torchaudio.load``.
        target_sr: Desired sample rate in Hz.

    Returns:
        The waveform with the channel dimension squeezed out, or ``None``
        when loading or resampling raised an exception. Failures are printed
        so that skipped files are visible instead of disappearing silently.
    """
    try:
        w, sr = torchaudio.load(filepath)
        if w.shape[0] > 1:
            # Down-mix to mono but keep the channel axis for the resampler.
            w = w.mean(dim=0, keepdim=True)
        if sr != target_sr:
            w = torchaudio.transforms.Resample(sr, target_sr)(w)
        return w.squeeze(0)
    except Exception as exc:
        # Best-effort loading: report and skip unreadable files. Catching
        # Exception (not a bare except) lets KeyboardInterrupt/SystemExit
        # propagate so a long run can still be interrupted.
        print(f'[load_audio] skipping {filepath}: {type(exc).__name__}: {exc}')
        return None

# Collapse the per-recording table into one row per participant.
grouped = (
    df.groupby('participant')
    .agg(label=('label', 'first'), filepaths=('filepath', list))
    .reset_index()
)

all_segments, all_labels, all_pids = [], [], []
silence = torch.zeros(GAP_SAMPLES)  # spacer placed between recordings

for _, patient in tqdm(grouped.iterrows(), total=len(grouped), desc='Building 2s segments'):
    patient_id = patient['participant']
    patient_label = patient['label']

    # Keep only recordings that loaded successfully and are non-empty.
    loaded = (load_audio(path, TARGET_SR) for path in patient['filepaths'])
    clips = [clip for clip in loaded if clip is not None and len(clip) > 0]
    if not clips:
        continue

    # Lay the recordings end to end: clip, gap, clip, gap, ..., clip.
    pieces = [clips[0]]
    for clip in clips[1:]:
        pieces += [silence, clip]
    stream = torch.cat(pieces)

    # Only whole segments are kept; a stream shorter than one segment still
    # yields a single segment, zero-padded symmetrically below.
    n_chunks = max(1, len(stream) // SEGMENT_SAMPLES)
    for k in range(n_chunks):
        chunk = stream[k * SEGMENT_SAMPLES:(k + 1) * SEGMENT_SAMPLES]
        shortfall = SEGMENT_SAMPLES - len(chunk)
        if shortfall > 0:
            left = shortfall // 2
            chunk = torch.nn.functional.pad(chunk, (left, shortfall - left))
        all_segments.append(chunk)
        all_labels.append(patient_label)
        all_pids.append(patient_id)

all_labels = np.array(all_labels)
all_pids = np.array(all_pids)
counts = list(Counter(all_pids).values())  # segments per participant

print(f'\nSegments: {len(all_segments)} @ 2s')
print(f'Patients: {len(np.unique(all_pids))}')
print(f'Seg/patient: mean={np.mean(counts):.1f}, min={min(counts)}, max={max(counts)}')

---
## Cell 3: Load Both Models

In [None]:
# === Wav2Vec2 ===
print('Loading Wav2Vec2-base...')
w2v_model = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base').to(DEVICE)
w2v_model.eval()
# The backbone is a fixed feature extractor: no gradients are needed.
for w2v_param in w2v_model.parameters():
    w2v_param.requires_grad = False
print(f'  Wav2Vec2: {sum(p.numel() for p in w2v_model.parameters())/1e6:.1f}M params (frozen)')

# Probe the backbone with one random segment to discover the embedding width.
with torch.no_grad():
    probe_wave = torch.randn(1, SEGMENT_SAMPLES).to(DEVICE)
    w2v_out = w2v_model(probe_wave).last_hidden_state.mean(dim=1)
W2V_DIM = w2v_out.shape[-1]
print(f'  Wav2Vec2 embedding dim: {W2V_DIM}')

# === HeAR ===
print('\nLoading HeAR...')
hear_model = AutoModel.from_pretrained('google/hear-pytorch', trust_remote_code=True).to(DEVICE)
hear_model.eval()
for hear_param in hear_model.parameters():
    hear_param.requires_grad = False
print(f'  HeAR: {sum(p.numel() for p in hear_model.parameters())/1e6:.1f}M params (frozen)')

# Same probe for HeAR; its input goes through preprocess_audio first.
with torch.no_grad():
    probe_spec = preprocess_audio(torch.randn(1, SEGMENT_SAMPLES))
    hear_out = hear_model.forward(
        probe_spec.to(DEVICE),
        return_dict=True,
        output_hidden_states=True,
    )
    hear_emb = hear_out.last_hidden_state.mean(dim=1)
HEAR_DIM = hear_emb.shape[-1]
print(f'  HeAR embedding dim: {HEAR_DIM}')

COMBINED_DIM = W2V_DIM + HEAR_DIM
print(f'\nCombined embedding: {W2V_DIM} + {HEAR_DIM} = {COMBINED_DIM}')

---
## Cell 4: Extract Dual Embeddings

In [None]:
EMBED_BATCH = 8  # kept small because both backbones run on every batch
w2v_chunks, hear_chunks = [], []

print(f'Extracting dual embeddings for {len(all_segments)} segments...')
print(f'Wav2Vec2 ({W2V_DIM}-dim) + HeAR ({HEAR_DIM}-dim)')

for start in tqdm(range(0, len(all_segments), EMBED_BATCH), desc='Extracting'):
    batch_audio = torch.stack(all_segments[start:start + EMBED_BATCH])

    with torch.no_grad():
        # Wav2Vec2 consumes the raw waveform; average over the time axis.
        w2v_hidden = w2v_model(batch_audio.to(DEVICE)).last_hidden_state
        w2v_chunks.append(w2v_hidden.mean(dim=1).cpu())

        # HeAR consumes the preprocessed audio; average over the token axis.
        hear_out = hear_model.forward(
            preprocess_audio(batch_audio).to(DEVICE),
            return_dict=True,
            output_hidden_states=True,
        )
        hear_chunks.append(hear_out.last_hidden_state.mean(dim=1).cpu())

w2v_embeddings = torch.cat(w2v_chunks, dim=0)
hear_embeddings = torch.cat(hear_chunks, dim=0)

# Per-segment feature vector: Wav2Vec2 features followed by HeAR features.
all_embeddings = torch.cat([w2v_embeddings, hear_embeddings], dim=1)

print(f'\nWav2Vec2 embeddings: {w2v_embeddings.shape}')
print(f'HeAR embeddings: {hear_embeddings.shape}')
print(f'Combined embeddings: {all_embeddings.shape}')
print(f'Memory: {all_embeddings.nbytes / 1024 / 1024:.1f} MB')

# Cache everything so later runs can skip the extraction step.
embedding_cache = {
    'w2v_embeddings': w2v_embeddings,
    'hear_embeddings': hear_embeddings,
    'combined_embeddings': all_embeddings,
    'labels': all_labels,
    'pids': all_pids,
}
torch.save(embedding_cache, 'dual_embeddings.pt')
print('Saved dual_embeddings.pt')

---
## Cell 5: Classifier + Config

In [None]:
EMBED_DIM = all_embeddings.shape[1]  # width of the concatenated feature vector


class DualBackboneClassifier(nn.Module):
    """MLP head that scores one concatenated Wav2Vec2 + HeAR embedding.

    Args:
        embed_dim: Size of the input feature vector.
        dropout: Dropout rate used before the first two linear layers
            (the last dropout layer is fixed at 0.2).
    """

    def __init__(self, embed_dim=EMBED_DIM, dropout=0.3):
        super().__init__()
        # Layer order is kept as-is so that `head.<index>` state_dict keys
        # stay compatible with saved checkpoints.
        layers = [
            nn.LayerNorm(embed_dim),
            nn.Dropout(dropout),
            nn.Linear(embed_dim, 256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 64),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(64, 1),
        ]
        self.head = nn.Sequential(*layers)

    def forward(self, x):
        """Return one raw logit per input row (trailing unit axis removed)."""
        logits = self.head(x)
        return logits.squeeze(-1)

class EmbeddingDataset(Dataset):
    """Index-aligned view over segment embeddings, labels and participant IDs.

    Args:
        embeddings: Indexable collection of per-segment feature vectors.
        labels: Per-segment labels; stored as a float32 tensor.
        pids: Per-segment participant identifiers, returned unchanged.
    """

    def __init__(self, embeddings, labels, pids):
        self.embeddings = embeddings
        self.labels = torch.tensor(labels, dtype=torch.float32)
        self.pids = pids

    def __len__(self):
        return len(self.embeddings)

    def __getitem__(self, idx):
        sample = dict(
            emb=self.embeddings[idx],
            label=self.labels[idx],
            pid=self.pids[idx],
        )
        return sample

# Training hyper-parameters shared by every fold.
BATCH_SIZE, LR = 64, 5e-4
EPOCHS, PATIENCE = 50, 7
N_FOLDS = 10

# Build a throwaway classifier only to report its parameter count.
probe_clf = DualBackboneClassifier()
n_params = sum(weight.numel() for weight in probe_clf.parameters())
del probe_clf

print('=== Trial 6 Config ===')
print(f'Input dim: {EMBED_DIM} (Wav2Vec2 {W2V_DIM} + HeAR {HEAR_DIM})')
print(f'Classifier: {n_params:,} params ({n_params/1e3:.1f}K)')
print(f'Architecture: {EMBED_DIM}→256→64→1 (3-layer MLP)')
print(f'Batch: {BATCH_SIZE}, LR: {LR}, Epochs: {EPOCHS}, Patience: {PATIENCE}')

---
## Cell 6: Training Function

In [None]:
def train_one_fold(fold_num, tr_emb, tr_lab, tr_pid, va_emb, va_lab, va_pid):
    """Train one classifier on a fold and score the held-out patients.

    Segment-level probabilities on the validation split are averaged per
    participant, and the patient-level AUROC is computed after every epoch.
    The weights of the best epoch are written to
    ``checkpoints_t6/dual_fold{fold_num}.pt``.

    NOTE: the best epoch is chosen using the same validation split that is
    reported, so the returned AUROC is an optimistic estimate.

    Args:
        fold_num: Fold number, used only in the checkpoint file name.
        tr_emb, tr_lab, tr_pid: Training embeddings, labels, participant IDs.
        va_emb, va_lab, va_pid: Validation embeddings, labels, participant IDs.

    Returns:
        ``(best_auroc, best_patient_logits)`` where the second item maps each
        validation participant ID to its mean predicted probability (sigmoid
        output, despite the name) at the best epoch.
    """
    train_ds = EmbeddingDataset(tr_emb, tr_lab, tr_pid)
    val_ds = EmbeddingDataset(va_emb, va_lab, va_pid)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

    model = DualBackboneClassifier().to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)
    criterion = nn.BCEWithLogitsLoss()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=1e-6)

    # Create the checkpoint directory once rather than on every improvement.
    os.makedirs('checkpoints_t6', exist_ok=True)
    checkpoint_path = f'checkpoints_t6/dual_fold{fold_num}.pt'

    best_auroc = 0
    best_patient_logits = {}
    patience_counter = 0

    for epoch in range(EPOCHS):
        model.train()
        train_loss = 0
        for batch in train_loader:
            emb = batch['emb'].to(DEVICE)
            labels = batch['label'].to(DEVICE)
            optimizer.zero_grad()
            loss = criterion(model(emb), labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss /= len(train_loader)
        scheduler.step()

        # Validation: collect segment-level probabilities with their IDs.
        model.eval()
        seg_probs, seg_labels, seg_pids = [], [], []
        with torch.no_grad():
            for batch in val_loader:
                emb = batch['emb'].to(DEVICE)
                probs = torch.sigmoid(model(emb)).cpu().numpy()
                seg_probs.extend(probs)
                seg_labels.extend(batch['label'].numpy())
                batch_pids = batch['pid']
                # Numeric IDs are collated into a tensor, and tensor elements
                # hash by identity; convert them to plain Python values so
                # segments of the same participant share one dict key.
                if torch.is_tensor(batch_pids):
                    batch_pids = batch_pids.tolist()
                seg_pids.extend(batch_pids)

        # Aggregate segments into one score per participant.
        pt_p, pt_l = {}, {}
        for pid, prob, lab in zip(seg_pids, seg_probs, seg_labels):
            pt_p.setdefault(pid, []).append(prob)
            pt_l[pid] = lab

        yt = np.array([pt_l[p] for p in pt_p])
        yp = np.array([np.mean(v) for v in pt_p.values()])
        # AUROC is undefined with a single class; fall back to chance level.
        auroc = roc_auc_score(yt, yp) if len(np.unique(yt)) > 1 else 0.5

        improved = ''
        if auroc > best_auroc:
            best_auroc = auroc
            best_patient_logits = {pid: np.mean(v) for pid, v in pt_p.items()}
            torch.save(model.state_dict(), checkpoint_path)
            patience_counter = 0
            improved = ' *'
        else:
            patience_counter += 1

        if (epoch + 1) % 5 == 0 or improved or patience_counter >= PATIENCE:
            print(f'  Epoch {epoch+1}/{EPOCHS} — loss: {train_loss:.4f}, AUROC: {auroc:.4f}{improved}')

        if patience_counter >= PATIENCE:
            print(f'  Early stopping at epoch {epoch+1}')
            break

    return best_auroc, best_patient_logits

print('Training function ready.')

---
## Cell 7: Run 10-Fold CV

In [None]:
# One label per participant: np.unique returns the index of each ID's first
# segment, which replaces a full scan of all segments per participant.
unique_pids, first_idx = np.unique(all_pids, return_index=True)
pid_labels = all_labels[first_idx]

# Cap the fold count by the size of the smaller class (with a 20% margin).
n_folds_actual = min(N_FOLDS, int(min(Counter(pid_labels).values()) * 0.8))
if n_folds_actual < 2:
    raise ValueError(
        f'Not enough participants per class for cross-validation: '
        f'class counts {dict(Counter(pid_labels))} allow only {n_folds_actual} fold(s).'
    )
skf = StratifiedKFold(n_splits=n_folds_actual, shuffle=True, random_state=42)

fold_aurocs = []
all_patient_logits = {}
all_patient_labels = {}

print(f'=== TRIAL 6: Dual-Backbone Ensemble ===')
print(f'Backbones: Wav2Vec2 ({W2V_DIM}d) + HeAR ({HEAR_DIM}d) = {COMBINED_DIM}d')
print(f'Classifier: 3-layer MLP')
print(f'Segments: {len(all_segments)} @ 2s')
print(f'Patients: {len(unique_pids)}')
print(f'Folds: {n_folds_actual}, Device: {DEVICE}')
print(f'\n--- Benchmarks ---')
print(f'DREAM winner:  0.743')
print(f'Trial 1 (W2V): 0.718 (re-run baseline)')
print(f'Trial 5 (HeAR): 0.719')
print(f'\nStarting training...\n')

# Folds are split on participants (not segments) so that all segments of a
# participant land on the same side of the split.
for fold, (train_idx, val_idx) in enumerate(skf.split(unique_pids, pid_labels)):
    print(f'=== Fold {fold+1}/{n_folds_actual} ===')

    train_pids = unique_pids[train_idx]
    val_pids = unique_pids[val_idx]

    # Segment-level masks derived from the participant-level split.
    tr_mask = np.isin(all_pids, train_pids)
    va_mask = ~tr_mask

    tr_emb = all_embeddings[tr_mask]
    tr_lab = all_labels[tr_mask]
    tr_pid = all_pids[tr_mask]
    va_emb = all_embeddings[va_mask]
    va_lab = all_labels[va_mask]
    va_pid = all_pids[va_mask]

    print(f'  Train: {len(tr_emb)} embeddings ({len(train_pids)} patients)')
    print(f'  Val:   {len(va_emb)} embeddings ({len(val_pids)} patients)')

    auroc, patient_logits = train_one_fold(
        fold+1, tr_emb, tr_lab, tr_pid, va_emb, va_lab, va_pid
    )

    fold_aurocs.append(auroc)
    all_patient_logits.update(patient_logits)
    all_patient_labels.update(zip(val_pids, pid_labels[val_idx]))

    print(f'  \u2705 Fold {fold+1} best AUROC: {auroc:.4f}\n')

# NOTE: each fold contributes its best-epoch validation AUROC, so the mean
# below is an optimistic estimate of out-of-sample performance.
print('=' * 60)
print(f'TRIAL 6 RESULT (Dual-Backbone Ensemble)')
print(f'Mean AUROC: {np.mean(fold_aurocs):.4f} +/- {np.std(fold_aurocs):.4f}')
print(f'Per-fold: {[f"{a:.3f}" for a in fold_aurocs]}')
print(f'\n--- Full Comparison ---')
print(f'DREAM winner:       0.743')
print(f'Trial 1 (W2V 3s):   0.718')
print(f'Trial 5 (HeAR 2s):  0.719')
print(f'Trial 6 (W2V+HeAR): {np.mean(fold_aurocs):.4f} +/- {np.std(fold_aurocs):.4f}')

beat_single = np.mean(fold_aurocs) > 0.719
beat_dream = np.mean(fold_aurocs) > 0.743
msg = '\u2705 BEAT DREAM CHALLENGE WINNER!' if beat_dream else ('\u2705 Beat both single backbones!' if beat_single else '\u26a0\ufe0f Did not beat single backbones.')
print(f'\n{msg}')

---
## Cell 8: ROC + Threshold

In [None]:
# Pool the out-of-fold patient scores in a deterministic (sorted) order.
pids_order = sorted(all_patient_logits)
y_true = np.array([all_patient_labels[pid] for pid in pids_order])
y_prob = np.array([all_patient_logits[pid] for pid in pids_order])

fpr_t6, tpr_t6, _ = roc_curve(y_true, y_prob)
auroc_t6 = roc_auc_score(y_true, y_prob)

fig, ax = plt.subplots(figsize=(8, 7))
ax.plot(fpr_t6, tpr_t6, 'r-', lw=2.5, label=f'Trial 6 Ensemble ({auroc_t6:.3f})')
ax.plot([0, 1], [0, 1], 'k--', alpha=0.2)  # chance diagonal
# Shaded corner: sensitivity >= 0.90 with FPR <= 0.30.
ax.axhspan(0.90, 1.0, xmin=0, xmax=0.30, alpha=0.08, color='green', label='WHO TPP zone')
ax.axhline(0.90, color='r', ls=':', alpha=0.3)
ax.axvline(0.30, color='g', ls=':', alpha=0.3)
ax.set_xlabel('FPR (1 - Specificity)', fontsize=12)
ax.set_ylabel('TPR (Sensitivity)', fontsize=12)
ax.set_title('RespiraHub Trial 6 \u2014 Dual-Backbone Ensemble', fontsize=14)
ax.legend(fontsize=11)
ax.grid(alpha=0.15)
fig.tight_layout()
fig.savefig('roc_trial6.png', dpi=150)
plt.show()

In [None]:
print(f'{"Thresh":>7} {"Sens":>7} {"Spec":>7} {"PPV":>7} {"NPV":>7}')
print('-' * 40)

# Ground-truth masks are threshold-independent, so build them once.
is_pos = y_true == 1
is_neg = y_true == 0

best_t, best_j = 0.5, -1

for thr in np.arange(0.15, 0.85, 0.05):
    called_pos = y_prob >= thr

    # Confusion-matrix counts at this threshold.
    tp = np.sum(called_pos & is_pos)
    tn = np.sum(~called_pos & is_neg)
    fp = np.sum(called_pos & is_neg)
    fn = np.sum(~called_pos & is_pos)

    # Each ratio falls back to 0 when its denominator is empty.
    sens = tp / (tp + fn) if (tp + fn) else 0
    spec = tn / (tn + fp) if (tn + fp) else 0
    ppv = tp / (tp + fp) if (tp + fp) else 0
    npv = tn / (tn + fn) if (tn + fn) else 0

    # Track the threshold with the largest Youden's J statistic.
    youden = sens + spec - 1
    if youden > best_j:
        best_j, best_t = youden, thr

    who_flag = ' \u2705 WHO' if (sens >= 0.90 and spec >= 0.70) else ''
    print(f'{thr:>7.2f} {sens:>7.3f} {spec:>7.3f} {ppv:>7.3f} {npv:>7.3f}{who_flag}')

print(f'\nBest Youden threshold: {best_t:.2f}')

---
## Cell 9: Save

In [None]:
# Per-patient out-of-fold predictions.
results_df = pd.DataFrame({
    'participant': pids_order,
    'true_label': y_true,
    'predicted_prob': y_prob,
})
results_df.to_csv('patient_predictions_trial6.csv', index=False)

# Fold statistics are reused for the JSON summary and the final printout.
auroc_mean = float(np.mean(fold_aurocs))
auroc_std = float(np.std(fold_aurocs))

summary = {
    'trial': 6,
    'approach': 'Dual-backbone ensemble: Wav2Vec2 + HeAR concatenated embeddings',
    'backbones': ['facebook/wav2vec2-base', 'google/hear-pytorch'],
    'embedding_dims': {'wav2vec2': W2V_DIM, 'hear': HEAR_DIM, 'combined': COMBINED_DIM},
    'n_participants': len(pids_order),
    'n_segments': len(all_segments),
    'segment_sec': SEGMENT_SEC,
    'n_folds': n_folds_actual,
    'auroc_mean': round(auroc_mean, 4),
    'auroc_std': round(auroc_std, 4),
    'auroc_per_fold': [round(float(score), 4) for score in fold_aurocs],
    'best_threshold': round(float(best_t), 2),
    'device': str(DEVICE),
}
with open('training_summary_trial6.json', 'w') as f:
    json.dump(summary, f, indent=2)

print('Saved:')
for artifact in (
    'patient_predictions_trial6.csv',
    'training_summary_trial6.json',
    'dual_embeddings.pt',
    'checkpoints_t6/dual_fold*.pt',
    'roc_trial6.png',
):
    print(f'  {artifact}')
print()
print('=' * 60)
print('TRIAL 6 COMPLETE')
print('=' * 60)
print(f'Trial 1 (W2V only):   0.718')
print(f'Trial 5 (HeAR only):  0.719')
print(f'Trial 6 (Ensemble):   {auroc_mean:.4f} +/- {auroc_std:.4f}')
print(f'DREAM winner:         0.743')
print(f'\nDelta vs best single: {auroc_mean - 0.719:+.4f}')
print(f'\nThis is the key result for the benchmark paper.')