# RespiraHub — Trial 1 (Reproduction)

**Goal:** Reproduce the mean patient-level AUROC of ~0.733 obtained in the original Trial 1.

**Config (exact match):**
- Model: Wav2Vec2-base, feature_extractor frozen, transformer trainable
- Head: Dropout(0.5) → Linear(768, 1)
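- Segmentation: concatenate all of a patient's coughs (50 ms of silence between recordings), then split into non-overlapping 3 s segments; a trailing remainder shorter than 3 s is dropped, and a patient with less than 3 s of audio in total gets a single zero-padded segment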
- Stratified K-fold CV with patient-level splits (10 folds, reduced automatically if the minority class is too small to support them)
- Batch 4, effective batch 32 (grad accum 8), AdamW (LR 3e-5, weight decay 0.01) with a linear warm-up over the first 10% of optimizer steps, 5 epochs
- All seeds fixed to 42

In [None]:
import os, json, random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchaudio
import matplotlib.pyplot as plt
from collections import Counter
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score, roc_curve
from torch.utils.data import Dataset, DataLoader
from transformers import Wav2Vec2Model
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# === FIXED SEEDS EVERYWHERE ===
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# GPU determinism. CUBLAS_WORKSPACE_CONFIG is a CUDA/cuBLAS setting (it does
# nothing on MPS or CPU) and only takes effect once deterministic algorithms are
# requested, e.g. torch.use_deterministic_algorithms(True). The cuDNN flags turn
# off non-deterministic kernel autotuning on CUDA and are harmless elsewhere.
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Prefer Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    DEVICE = torch.device('mps')
elif torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

print(f'PyTorch {torch.__version__}, Device: {DEVICE}')
print(f'Seed: {SEED}')

---
## Cell 2: Load Data + Segmentation (3s, concatenate all)

In [None]:
# Dataset root. Override it with the TB_DATA_ROOT environment variable; the
# default is the original author's local checkout.
DATA_ROOT = os.environ.get('TB_DATA_ROOT', '/Users/aida/code/development/tb-datasets/data')
AUDIO_DIR = os.path.join(DATA_ROOT, 'solicited', '')  # keeps the trailing separator
METADATA_DIR = os.path.join(DATA_ROOT, 'metadata')
CLINICAL_PATH = os.path.join(METADATA_DIR, 'CODA_TB_Clinical_Meta_Info.csv')
ADDITIONAL_PATH = os.path.join(METADATA_DIR, 'CODA_TB_additional_variables_train.csv')
SOLICITED_PATH = os.path.join(METADATA_DIR, 'CODA_TB_Solicited_Meta_Info.csv')

clinical = pd.read_csv(CLINICAL_PATH)
additional = pd.read_csv(ADDITIONAL_PATH)
solicited = pd.read_csv(SOLICITED_PATH)

# Participant-level metadata; tb_status (cast to int) is the training target.
meta = clinical.merge(additional, on='participant', how='left')
meta['label'] = meta['tb_status'].astype(int)

# One row per solicited cough recording, joined to its participant's metadata.
df = solicited.merge(meta, on='participant', how='left')
# A recording whose participant has no clinical row gets label NaN from the
# left join. It would reach training as a NaN BCE target and break
# roc_auc_score, so such rows are dropped here.
n_unlabeled = int(df['label'].isna().sum())
if n_unlabeled:
    print(f'Dropping {n_unlabeled} cough files with no clinical label')
df = df.dropna(subset=['label']).copy()
df['label'] = df['label'].astype(int)

# Keep only recordings whose audio file is actually present on disk.
df['filepath'] = df['filename'].apply(lambda f: os.path.join(AUDIO_DIR, f))
df['file_exists'] = df['filepath'].apply(os.path.exists)
df = df[df['file_exists']].reset_index(drop=True)

print(f'Participants: {meta["label"].count()}, TB+: {meta["label"].sum()}')
print(f'Cough files: {len(df)} from {df["participant"].nunique()} participants')

# === Segmentation: concatenate all + split 3s ===
TARGET_SR = 16000                               # Hz; every recording is resampled to this rate
SEGMENT_SEC = 3.0                               # duration of one model input, in seconds
SEGMENT_SAMPLES = int(SEGMENT_SEC * TARGET_SR)  # samples per segment (48000)
GAP_SAMPLES = int(0.05 * TARGET_SR)             # 50 ms of silence inserted between coughs

def load_audio(filepath, target_sr=16000):
    """Load an audio file as a mono 1-D tensor at ``target_sr`` Hz.

    Multi-channel audio is averaged down to one channel, and other sample rates
    are resampled to ``target_sr``.

    Returns:
        The waveform tensor, or ``None`` if the file could not be loaded or
        processed. Callers skip ``None`` results.
    """
    try:
        w, sr = torchaudio.load(filepath)
        if w.shape[0] > 1:
            w = w.mean(dim=0, keepdim=True)
        if sr != target_sr:
            w = torchaudio.transforms.Resample(sr, target_sr)(w)
        return w.squeeze(0)
    except Exception as exc:
        # Best-effort: skip unreadable files, but (unlike a bare `except:`) let
        # KeyboardInterrupt/SystemExit through so the loading loop can be stopped.
        # print rather than warnings.warn, because the notebook silences warnings.
        print(f'[load_audio] skipping {filepath}: {exc}')
        return None

def _concat_with_gaps(waves, gap_len):
    """Join 1-D waveforms end to end with ``gap_len`` zero samples between neighbours."""
    silence = torch.zeros(gap_len)
    pieces = []
    for k, wave in enumerate(waves):
        if k > 0:
            pieces.append(silence)
        pieces.append(wave)
    return torch.cat(pieces)


def _split_fixed(signal, seg_len):
    """Cut ``signal`` into consecutive, non-overlapping ``seg_len``-sample windows.

    A trailing remainder shorter than ``seg_len`` is discarded. If the whole
    signal is shorter than one window, it is zero-padded (centered) into a
    single window.
    """
    n_full = len(signal) // seg_len
    if n_full == 0:
        deficit = seg_len - len(signal)
        left = deficit // 2
        return [torch.nn.functional.pad(signal, (left, deficit - left))]
    return [signal[k * seg_len:(k + 1) * seg_len] for k in range(n_full)]


# One row per participant: its (shared) label and all its recording paths.
grouped = (
    df.groupby('participant')
      .agg(label=('label', 'first'), filepaths=('filepath', list))
      .reset_index()
)

all_segments, all_labels, all_pids = [], [], []

for _, row in tqdm(grouped.iterrows(), total=len(grouped), desc='Building 3s segments'):
    # Load every recording of this participant, dropping unreadable/empty ones.
    clips = [load_audio(fp, TARGET_SR) for fp in row['filepaths']]
    clips = [c for c in clips if c is not None and len(c) > 0]
    if not clips:
        continue
    windows = _split_fixed(_concat_with_gaps(clips, GAP_SAMPLES), SEGMENT_SAMPLES)
    all_segments.extend(windows)
    all_labels.extend([row['label']] * len(windows))
    all_pids.extend([row['participant']] * len(windows))

all_labels = np.array(all_labels)
all_pids = np.array(all_pids)
seg_counts = Counter(all_pids)  # participant -> number of segments
counts = list(seg_counts.values())

print(f'\nTotal segments: {len(all_segments)} @ 3s')
print(f'Patients: {len(seg_counts)}')
print(f'Seg/patient: mean={np.mean(counts):.1f}, min={min(counts)}, max={max(counts)}')
print(f'1 segment only: {counts.count(1)}')
print(f'1 segment only: {sum(1 for c in counts if c == 1)}')

---
## Cell 3: Model + Config

In [None]:
class CoughClassifier(nn.Module):
    """Binary cough classifier: Wav2Vec2 encoder, temporal mean pooling, linear head.

    Only the convolutional feature extractor of Wav2Vec2 is frozen; the
    transformer layers and the head are trained. ``forward`` returns one raw
    logit per input (apply a sigmoid for a probability).
    """

    def __init__(self, model_name='facebook/wav2vec2-base', dropout=0.5):
        super().__init__()
        encoder = Wav2Vec2Model.from_pretrained(model_name)
        encoder.feature_extractor._freeze_parameters()  # conv front-end stays fixed
        self.wav2vec2 = encoder
        width = encoder.config.hidden_size
        self.head = nn.Sequential(nn.Dropout(dropout), nn.Linear(width, 1))

    def forward(self, x):
        # NOTE(review): x is presumably a (batch, samples) raw-waveform tensor — confirm.
        frames = self.wav2vec2(x).last_hidden_state  # (B, T, hidden)
        logits = self.head(frames.mean(dim=1))       # (B, 1)
        return logits.squeeze(-1)                    # (B,)

class CoughDataset(Dataset):
    """Map-style dataset over pre-cut audio segments.

    Each item is a dict holding the segment waveform (``audio``), its binary
    target as a float32 scalar tensor (``label``), and the participant ID the
    segment came from (``pid``).
    """

    def __init__(self, segments, labels, pids):
        self.segments = segments
        self.labels = torch.tensor(labels, dtype=torch.float32)
        self.pids = pids

    def __len__(self):
        return len(self.segments)

    def __getitem__(self, idx):
        audio, label, pid = self.segments[idx], self.labels[idx], self.pids[idx]
        return {'audio': audio, 'label': label, 'pid': pid}

# --- Trial 1 hyperparameters ---
BATCH_SIZE, GRAD_ACCUM = 4, 8  # 4 x 8 = 32 samples per optimizer step
LR = 3e-5
EPOCHS = 5
N_FOLDS = 10

print('=== Trial 1 Config ===')
print(f'Batch: {BATCH_SIZE} (effective: {BATCH_SIZE * GRAD_ACCUM})')
print(f'LR: {LR}, Epochs: {EPOCHS}, Folds: {N_FOLDS}')
print(f'Segment: {SEGMENT_SEC}s = {SEGMENT_SAMPLES} samples')

# Build the model once only to report parameter counts, then release it.
m = CoughClassifier()
params = list(m.parameters())
n_total = sum(p.numel() for p in params)
n_train = sum(p.numel() for p in params if p.requires_grad)
print(f'Total params: {n_total/1e6:.1f}M, Trainable: {n_train/1e6:.1f}M')
del m, params

---
## Cell 4: Training Function

In [None]:
def seed_worker(worker_id):
    """Seed Python's ``random`` and NumPy inside a DataLoader worker process.

    Follows the PyTorch reproducibility recipe. PyTorch already gives each
    worker a distinct torch seed, derived from the loader's ``generator`` and
    re-drawn every epoch. Deriving the NumPy/``random`` seeds from
    ``torch.initial_seed()`` therefore makes them differ per worker and per
    epoch, while staying reproducible for a fixed generator seed. A fixed
    ``SEED + worker_id`` would replay identical streams every epoch.

    Args:
        worker_id: Index of the worker (unused; kept for the
            ``worker_init_fn`` signature).

    Only runs when ``num_workers > 0``.
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# Dedicated RNG (separate from the global torch RNG) that drives DataLoader shuffling.
g = torch.Generator().manual_seed(SEED)

def train_one_fold(fold_num, train_segs, train_labs, train_pids,
                   val_segs, val_labs, val_pids):
    """Fine-tune a fresh CoughClassifier on one CV fold.

    Training setup:
    - AdamW (weight decay 0.01) for EPOCHS epochs.
    - Gradient accumulation over GRAD_ACCUM mini-batches of BATCH_SIZE, with
      gradient-norm clipping at 1.0.
    - Linear LR warm-up from 0.1*LR to LR over the first 10% of optimizer steps;
      the LR then stays constant.

    After each epoch, the validation segments are scored and averaged per
    participant (soft voting) to compute patient-level AUROC. The best epoch's
    weights are saved to ``checkpoints_t1/wav2vec2_fold{fold_num}.pt``.

    NOTE(review): the best epoch is selected on the same fold whose AUROC is
    returned, so the reported AUROC is optimistically biased.

    Args:
        fold_num: 1-based fold index; used for seeding and the checkpoint name.
        train_segs, val_segs: lists of equal-length 1-D waveform tensors.
        train_labs, val_labs: per-segment binary labels.
        train_pids, val_pids: per-segment participant IDs.

    Returns:
        ``(best_auroc, best_patient_probs)``. ``best_patient_probs`` maps
        participant ID to the mean sigmoid probability from the best epoch.

    Raises:
        ValueError: if the training split is empty.
    """
    # Re-seed every RNG that influences this fold, including the shared
    # DataLoader generator `g`. A fold's result then no longer depends on which
    # folds (or cell re-runs) came before it in the same kernel.
    fold_seed = SEED + fold_num
    torch.manual_seed(fold_seed)
    np.random.seed(fold_seed)
    random.seed(fold_seed)
    g.manual_seed(fold_seed)

    train_ds = CoughDataset(train_segs, train_labs, train_pids)
    val_ds = CoughDataset(val_segs, val_labs, val_pids)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=0, worker_init_fn=seed_worker, generator=g)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    if len(train_loader) == 0:
        raise ValueError(f'Fold {fold_num}: empty training set')

    model = CoughClassifier().to(DEVICE)
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=LR, weight_decay=0.01,
    )
    criterion = nn.BCEWithLogitsLoss()

    # Warm-up length counts only full accumulation windows. LinearLR holds its
    # final factor (1.0) after total_iters, so there is no decay afterwards.
    total_steps = (len(train_loader) * EPOCHS) // GRAD_ACCUM
    warmup_steps = int(total_steps * 0.1)
    scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=0.1, total_iters=max(warmup_steps, 1)
    )

    best_auroc = float('-inf')  # guarantees the first epoch is recorded
    best_patient_probs = {}

    for epoch in range(EPOCHS):
        model.train()
        train_loss = 0.0
        optimizer.zero_grad()

        for step, batch in enumerate(train_loader):
            audio = batch['audio'].to(DEVICE)
            labels = batch['label'].to(DEVICE)
            # Scale so the accumulated gradient is the mean over the window.
            loss = criterion(model(audio), labels) / GRAD_ACCUM
            loss.backward()

            if (step + 1) % GRAD_ACCUM == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
            train_loss += loss.item() * GRAD_ACCUM  # undo scaling for logging

        # Apply gradients left over from a final, partial accumulation window
        # (the scheduler is not advanced for this step).
        if (step + 1) % GRAD_ACCUM != 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()

        train_loss /= len(train_loader)

        # Validation: segment probabilities -> per-participant mean (soft voting).
        model.eval()
        seg_probs, seg_labels, seg_pids = [], [], []
        with torch.no_grad():
            for batch in val_loader:
                audio = batch['audio'].to(DEVICE)
                probs = torch.sigmoid(model(audio)).cpu().numpy()
                seg_probs.extend(probs)
                seg_labels.extend(batch['label'].numpy())
                # default_collate stacks numeric IDs into a tensor, and 0-d
                # tensors hash by identity. Convert them back to plain scalars
                # before using them as dict keys; string IDs pass through as-is.
                seg_pids.extend(p.item() if torch.is_tensor(p) else p
                                for p in batch['pid'])

        pt_p, pt_l = {}, {}
        for pid, prob, lab in zip(seg_pids, seg_probs, seg_labels):
            pt_p.setdefault(pid, []).append(prob)
            pt_l[pid] = lab

        patient_probs = {pid: np.mean(v) for pid, v in pt_p.items()}
        yt = np.array([pt_l[p] for p in patient_probs])
        yp = np.array(list(patient_probs.values()))
        # AUROC is undefined when only one class is present; report chance level.
        auroc = roc_auc_score(yt, yp) if len(np.unique(yt)) > 1 else 0.5

        improved = ' *' if auroc > best_auroc else ''
        print(f'  Epoch {epoch+1}/{EPOCHS} \u2014 loss: {train_loss:.4f}, AUROC: {auroc:.4f}{improved}')

        if auroc > best_auroc:
            best_auroc = auroc
            best_patient_probs = patient_probs
            os.makedirs('checkpoints_t1', exist_ok=True)
            torch.save(model.state_dict(), f'checkpoints_t1/wav2vec2_fold{fold_num}.pt')

    # The scheduler holds the optimizer, which holds the parameters, so drop all
    # three before asking the allocator to release cached device memory.
    del model, optimizer, scheduler
    if DEVICE.type == 'mps': torch.mps.empty_cache()
    elif DEVICE.type == 'cuda': torch.cuda.empty_cache()

    return best_auroc, best_patient_probs

print('Training function ready.')

---
## Cell 5: Run 10-Fold CV

In [None]:
# Participant-level view for stratified splitting. All segments of a
# participant share one label, so the label at each ID's first segment is used.
# np.unique returns the sorted IDs plus the index of each ID's first occurrence.
unique_pids, first_seg_idx = np.unique(all_pids, return_index=True)
pid_labels = all_labels[first_seg_idx]
pid_to_label = dict(zip(unique_pids, pid_labels))

# Use N_FOLDS folds, but fewer if the minority class is small
# (at most 80% of its participant count).
n_folds_actual = min(N_FOLDS, int(min(Counter(pid_labels).values()) * 0.8))
if n_folds_actual < 2:
    raise ValueError(
        f'Minority class too small for stratified CV (n_folds_actual={n_folds_actual})'
    )
skf = StratifiedKFold(n_splits=n_folds_actual, shuffle=True, random_state=SEED)

fold_aurocs = []
all_patient_logits = {}   # participant -> best-epoch mean probability (out-of-fold)
all_patient_labels = {}

print(f'=== TRIAL 1: Wav2Vec2 Cough-Only Baseline ===')
print(f'Model: Wav2Vec2-base (feature_extractor frozen)')
print(f'Segments: {len(all_segments)} @ 3s')
print(f'Patients: {len(unique_pids)}')
print(f'Folds: {n_folds_actual}, Device: {DEVICE}')
print(f'\n--- Benchmarks ---')
print(f'Zambia study (audio-only): 0.852')
print(f'DREAM Challenge winner:    0.743')
print(f'\nStarting training...\n')

for fold, (train_idx, val_idx) in enumerate(skf.split(unique_pids, pid_labels)):
    print(f'=== Fold {fold+1}/{n_folds_actual} ===')

    train_pids_set = set(unique_pids[train_idx])
    val_pids_set = set(unique_pids[val_idx])

    # Route every segment to train or validation by its participant, so no
    # participant appears on both sides of the split.
    tr_s, tr_l, tr_p = [], [], []
    va_s, va_l, va_p = [], [], []

    for seg, lab, pid in zip(all_segments, all_labels, all_pids):
        if pid in train_pids_set:
            tr_s.append(seg); tr_l.append(lab); tr_p.append(pid)
        else:
            va_s.append(seg); va_l.append(lab); va_p.append(pid)

    print(f'  Train: {len(tr_s)} segments ({len(train_pids_set)} patients)')
    print(f'  Val:   {len(va_s)} segments ({len(val_pids_set)} patients)')

    auroc, patient_logits = train_one_fold(
        fold+1, tr_s, tr_l, tr_p, va_s, va_l, va_p
    )

    fold_aurocs.append(auroc)
    all_patient_logits.update(patient_logits)
    for pid in val_pids_set:
        all_patient_labels[pid] = pid_to_label[pid]

    print(f'  \u2705 Fold {fold+1} best AUROC: {auroc:.4f}\n')

print('=' * 60)
print(f'TRIAL 1 RESULT')
print(f'Mean AUROC: {np.mean(fold_aurocs):.4f} +/- {np.std(fold_aurocs):.4f}')
print(f'Per-fold: {[f"{a:.3f}" for a in fold_aurocs]}')
print(f'\nDREAM winner:  0.743')
print(f'Zambia study:  0.852')

---
## Cell 6: ROC + Threshold

In [None]:
# Pool the out-of-fold patient probabilities from all folds into one ROC curve.
pids_order = sorted(all_patient_logits.keys())
y_true = np.array([all_patient_labels[p] for p in pids_order])
y_prob = np.array([all_patient_logits[p] for p in pids_order])

fpr, tpr, _ = roc_curve(y_true, y_prob)
auroc = roc_auc_score(y_true, y_prob)

plt.figure(figsize=(8, 7))
plt.plot(fpr, tpr, 'r-', lw=2.5, label=f'Trial 1 ({auroc:.3f})')
plt.plot([0, 1], [0, 1], 'k--', alpha=0.2)
# Target region used by the sweep below: sensitivity >= 0.90 and specificity
# >= 0.70 (i.e. FPR <= 0.30). fill_betweenx works in data coordinates. The
# previous axhspan xmin/xmax are axes fractions and do not line up with
# FPR = 0.30 once matplotlib adds its default axis margins.
plt.fill_betweenx([0.90, 1.0], 0, 0.30, alpha=0.08, color='green', label='WHO TPP zone')
plt.axhline(0.90, color='r', ls=':', alpha=0.3)
plt.axvline(0.30, color='g', ls=':', alpha=0.3)
plt.xlabel('FPR (1 - Specificity)', fontsize=12)
plt.ylabel('TPR (Sensitivity)', fontsize=12)
plt.title('RespiraHub Trial 1 \u2014 Wav2Vec2 Baseline', fontsize=14)
plt.legend(fontsize=11); plt.grid(alpha=0.15)
plt.tight_layout()
plt.savefig('roc_trial1.png', dpi=150)
plt.show()

print(f'\n{"Thresh":>7} {"Sens":>7} {"Spec":>7} {"PPV":>7} {"NPV":>7}')
print('-' * 40)
# Coarse sweep 0.10..0.85. Rounding keeps each threshold exactly at its printed
# value (np.arange accumulates float error, e.g. 0.15000000000000002).
# NOTE(review): the Youden threshold is chosen on the same pooled predictions it
# is reported for, so its operating point is optimistic.
best_t, best_j = 0.5, -1
for t in np.round(np.arange(0.10, 0.90, 0.05), 2):
    pred = (y_prob >= t).astype(int)
    tp = np.sum((pred == 1) & (y_true == 1))
    tn = np.sum((pred == 0) & (y_true == 0))
    fp = np.sum((pred == 1) & (y_true == 0))
    fn = np.sum((pred == 0) & (y_true == 1))
    sens = tp/(tp+fn) if (tp+fn) else 0
    spec = tn/(tn+fp) if (tn+fp) else 0
    ppv = tp/(tp+fp) if (tp+fp) else 0
    npv = tn/(tn+fn) if (tn+fn) else 0
    j = sens + spec - 1  # Youden's J
    if j > best_j: best_j, best_t = j, t
    flag = ' \u2705 WHO' if (sens >= 0.90 and spec >= 0.70) else ''
    print(f'{t:>7.2f} {sens:>7.3f} {spec:>7.3f} {ppv:>7.3f} {npv:>7.3f}{flag}')
print(f'\nBest Youden threshold: {best_t:.2f}')

---
## Cell 7: Save

In [None]:
# Per-participant pooled predictions, one row per patient seen in validation.
results_df = pd.DataFrame(
    {'participant': pids_order, 'true_label': y_true, 'predicted_prob': y_prob}
)
results_df.to_csv('patient_predictions_trial1.csv', index=False)

# Headline numbers: mean and std across folds of each fold's best-epoch AUROC.
mean_auroc = float(np.mean(fold_aurocs))
std_auroc = float(np.std(fold_aurocs))

summary = dict(
    trial=1,
    approach='Wav2Vec2-base, feature_extractor frozen, concat+split 3s',
    n_participants=len(pids_order),
    n_segments=len(all_segments),
    segment_sec=SEGMENT_SEC,
    n_folds=n_folds_actual,
    auroc_mean=round(mean_auroc, 4),
    auroc_std=round(std_auroc, 4),
    auroc_per_fold=[round(float(a), 4) for a in fold_aurocs],
    best_threshold=round(float(best_t), 2),
    seed=SEED,
    device=str(DEVICE),
)
with open('training_summary_trial1.json', 'w') as fh:
    json.dump(summary, fh, indent=2)

print('Saved:')
for artifact in ('patient_predictions_trial1.csv',
                 'training_summary_trial1.json',
                 'checkpoints_t1/wav2vec2_fold*.pt',
                 'roc_trial1.png'):
    print(f'  {artifact}')
print(f'\nFinal AUROC: {mean_auroc:.4f} +/- {std_auroc:.4f}')