# RespiraHub — Trial 2: Fixed Segmentation

**Perubahan dari Trial 1:**
1. Segmentasi: pad per-file ke 3s (bukan concatenate+split)
2. Hyfe filter: >= 0.5 (bukan 0.8)
3. Sisanya sama (Wav2Vec2-base, 10-fold CV, same hyperparams)

**Target:** AUROC >= 0.80 (ideally >= 0.85)

---
## Cell 1: Setup

In [None]:
import os, json
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchaudio
import matplotlib.pyplot as plt
from collections import Counter
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score, roc_curve
from torch.utils.data import Dataset, DataLoader
from transformers import Wav2Vec2Model
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Choose the compute backend in priority order: Apple MPS, then CUDA, then CPU.
# The backend name and its banner are selected first; the device object is
# built once afterwards.
if torch.backends.mps.is_available():
    _backend, _banner = 'mps', 'Using Apple Silicon MPS'
elif torch.cuda.is_available():
    _backend, _banner = 'cuda', 'Using CUDA'
else:
    _backend, _banner = 'cpu', 'Using CPU'

DEVICE = torch.device(_backend)
print(_banner)

print(f'PyTorch {torch.__version__}, Device: {DEVICE}')

---
## Cell 2: Load & Merge Metadata (Hyfe filter 0.5)

In [None]:
# Dataset locations (absolute paths on the author's machine).
AUDIO_DIR = '/Users/aida/code/development/tb-datasets/data/solicited/'
CLINICAL_PATH = '/Users/aida/code/development/tb-datasets/data/metadata/CODA_TB_Clinical_Meta_Info.csv'
ADDITIONAL_PATH = '/Users/aida/code/development/tb-datasets/data/metadata/CODA_TB_additional_variables_train.csv'
SOLICITED_PATH = '/Users/aida/code/development/tb-datasets/data/metadata/CODA_TB_Solicited_Meta_Info.csv'

clinical, additional, solicited = (
    pd.read_csv(path) for path in (CLINICAL_PATH, ADDITIONAL_PATH, SOLICITED_PATH)
)

# Participant-level table; the binary label is 1 where tb_status == 'positive'.
meta = pd.merge(clinical, additional, on='participant', how='left')
meta['label'] = meta['tb_status'].eq('positive').astype(int)

print(f'Participants: {len(meta)}')
print(f'TB+: {meta["label"].sum()}, TB-: {meta["label"].eq(0).sum()}')
print(f'Prevalence: {meta["label"].mean():.1%}')

# One row per solicited cough recording, with participant metadata attached.
df = pd.merge(solicited, meta, on='participant', how='left')

# === CHANGE: Hyfe filter 0.5 instead of 0.8 ===
hyfe_score = df['sound_prediction_score']
print(f'\nTotal cough files: {len(df)}')
print(f'Trial 1 (>=0.8): {(hyfe_score >= 0.8).sum()}')
df = df.loc[hyfe_score >= 0.5].reset_index(drop=True)
print(f'Trial 2 (>=0.5): {len(df)}')

# Resolve each filename against AUDIO_DIR and keep only files present on disk.
df['filepath'] = df['filename'].map(lambda name: os.path.join(AUDIO_DIR, name))
df['file_exists'] = df['filepath'].map(os.path.exists)
df = df.loc[df['file_exists']].reset_index(drop=True)

print(f'\nFinal: {len(df)} cough files from {df["participant"].nunique()} participants')

# coughs per patient
coughs_per_pt = df.groupby('participant').size()
print('\nCoughs per patient:')
print(f'  Min: {coughs_per_pt.min()}, Max: {coughs_per_pt.max()}')
print(f'  Mean: {coughs_per_pt.mean():.1f}, Median: {coughs_per_pt.median():.1f}')

---
## Cell 3: New Segmentation — Pad Per-File

**Trial 1 (SALAH):** concatenate semua cough → split 3s → banyak zero padding, 58% patient cuma 1 segment

**Trial 2 (FIX):** setiap file cough individual di-pad ke 3s = 1 segment per file. Patient dengan 8 files → 8 segments.

In [None]:
TARGET_SR = 16000
SEGMENT_SEC = 3.0
SEGMENT_SAMPLES = int(TARGET_SR * SEGMENT_SEC)  # 48000


def load_and_pad(filepath, target_sr=TARGET_SR, segment_samples=SEGMENT_SAMPLES):
    """Load one cough file as a fixed-length mono waveform.

    The audio is down-mixed to mono (channel mean), resampled to
    ``target_sr`` when the file's rate differs, then brought to exactly
    ``segment_samples`` samples: shorter clips are zero-padded equally on
    both sides (the extra sample goes to the right when the padding is odd),
    longer clips are center-cropped.

    Args:
        filepath: Path of the audio file to read.
        target_sr: Output sample rate in Hz.
        segment_samples: Output length in samples.

    Returns:
        A 1-D tensor of length ``segment_samples``, or ``None`` when the file
        could not be loaded or processed (the reason is printed).
    """
    try:
        w, sr = torchaudio.load(filepath)
        if w.shape[0] > 1:
            w = w.mean(dim=0, keepdim=True)
        if sr != target_sr:
            w = torchaudio.transforms.Resample(sr, target_sr)(w)
        w = w.squeeze(0)

        n_samples = len(w)
        if n_samples < segment_samples:
            pad_total = segment_samples - n_samples
            pad_left = pad_total // 2
            pad_right = pad_total - pad_left
            w = torch.nn.functional.pad(w, (pad_left, pad_right))
        elif n_samples > segment_samples:
            start = (n_samples - segment_samples) // 2
            w = w[start:start + segment_samples]

        return w
    except Exception as exc:
        # Best-effort loading: one bad file must not abort the whole run, but
        # the failure is reported instead of vanishing. Catching Exception
        # (not a bare except) lets KeyboardInterrupt/SystemExit propagate, so
        # interrupting the notebook actually stops the loading loop.
        print(f'[load_and_pad] skipping {filepath}: {type(exc).__name__}: {exc}')
        return None


print(f'Target: {TARGET_SR}Hz, {SEGMENT_SEC}s = {SEGMENT_SAMPLES} samples')
print(f'Strategy: pad each file individually (center-padded)')

In [None]:
# === Build segments: 1 file = 1 segment ===
# Three parallel lists (waveform, label, participant id), one entry per
# successfully loaded file; files that fail to load are only counted.
all_segments, all_labels, all_pids = [], [], []
skipped = 0

for _, row in tqdm(df.iterrows(), total=len(df), desc='Loading audio'):
    seg = load_and_pad(row['filepath'], TARGET_SR, SEGMENT_SAMPLES)
    if seg is None:
        skipped += 1
        continue
    all_segments.append(seg)
    all_labels.append(row['label'])
    all_pids.append(row['participant'])

all_labels, all_pids = np.array(all_labels), np.array(all_pids)

print(f'\nTotal segments: {len(all_segments)}')
print(f'Skipped: {skipped}')
print(f'Unique participants: {len(np.unique(all_pids))}')

# === Compare with Trial 1 ===
# Number of segments contributed by each participant.
pid_counts = Counter(all_pids)
counts = list(pid_counts.values())
print('\nSegments per patient (Trial 2):')
print(f'  Min: {min(counts)}, Max: {max(counts)}')
print(f'  Mean: {np.mean(counts):.1f}, Median: {np.median(counts):.1f}')
print(f'  1 segment only: {sum(c == 1 for c in counts)}')
print(f'  >=3 segments: {sum(c >= 3 for c in counts)}')
print(f'  >=5 segments: {sum(c >= 5 for c in counts)}')
print('\n--- Trial 1 was: 1818 segments, mean 1.7/patient, 626 with only 1 ---')

---
## Cell 4: Audio Content Check

In [None]:
# For the first 500 segments, the fraction of samples whose magnitude exceeds
# 1e-6 (i.e. samples that are not zero padding / digital silence).
audio_ratios = [
    (seg.abs() > 1e-6).sum().item() / len(seg)
    for seg in all_segments[:500]
]

print('Audio content ratio (sample of 500):')
print(f'  Mean: {np.mean(audio_ratios):.1%}')
print(f'  Min: {np.min(audio_ratios):.1%}')
print(f'  Max: {np.max(audio_ratios):.1%}')
print(f'  <10% audio: {sum(r < 0.1 for r in audio_ratios)}')
print(f'  <20% audio: {sum(r < 0.2 for r in audio_ratios)}')
print('\nNote: ~17% audio content expected (0.5s cough in 3s segment)')

---
## Cell 5: Model + Dataset + Config

In [None]:
class CoughClassifier(nn.Module):
    """Wav2Vec2 encoder followed by a dropout + linear single-logit head.

    The encoder's convolutional feature extractor is frozen through its
    ``_freeze_parameters()`` hook. The head consumes the encoder's last hidden
    state averaged over the time axis.

    NOTE(review): waveforms go into the encoder exactly as given; the
    zero-mean/unit-variance normalization usually applied by the Hugging Face
    feature extractor is not performed here — confirm that is intended.
    """

    def __init__(self, model_name='facebook/wav2vec2-base', dropout=0.5):
        super().__init__()
        encoder = Wav2Vec2Model.from_pretrained(model_name)
        encoder.feature_extractor._freeze_parameters()
        self.wav2vec2 = encoder
        self.head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(encoder.config.hidden_size, 1),
        )

    def forward(self, x):
        """Return one raw logit per input waveform (last dimension squeezed)."""
        frames = self.wav2vec2(x).last_hidden_state
        # Mean over dim 1 pools every frame, including frames that come from
        # zero-padded regions of the segment.
        return self.head(frames.mean(dim=1)).squeeze(-1)

class CoughDataset(Dataset):
    """In-memory dataset over parallel sequences of segments, labels and ids.

    Labels are converted once to a float32 tensor; segments and participant
    ids are stored as passed in.
    """

    def __init__(self, segments, labels, pids):
        self.segments, self.pids = segments, pids
        self.labels = torch.tensor(labels, dtype=torch.float32)

    def __len__(self):
        return len(self.segments)

    def __getitem__(self, idx):
        """Return a dict with keys 'audio', 'label' and 'pid' for item ``idx``."""
        return dict(
            audio=self.segments[idx],
            label=self.labels[idx],
            pid=self.pids[idx],
        )

# Training hyperparameters shared by every fold.
BATCH_SIZE, GRAD_ACCUM = 4, 8
LR = 3e-5
EPOCHS = 5
WARMUP_RATIO = 0.1
N_FOLDS = 10

print(f'Batch: {BATCH_SIZE} (effective: {BATCH_SIZE * GRAD_ACCUM})')
print(f'Epochs: {EPOCHS}, LR: {LR}, Folds: {N_FOLDS}')

# Smoke test: build the model once on CPU, report the trainable parameter
# count, and push a random batch of two segments through it.
print('\nLoading Wav2Vec2-base...')
probe_model = CoughClassifier()
n_train = sum(
    param.numel() for param in probe_model.parameters() if param.requires_grad
)
print(f'Trainable params: {n_train/1e6:.1f}M')
dummy = torch.randn(2, SEGMENT_SAMPLES)
with torch.no_grad():
    out = probe_model(dummy)
print(f'Forward pass OK: {out.shape}')
del probe_model

---
## Cell 6: Training Function

In [None]:
def train_one_fold(fold_num, train_segs, train_labs, train_pids,
                   val_segs, val_labs, val_pids):
    """Fine-tune a fresh CoughClassifier on one fold and score it per patient.

    Each epoch trains with gradient accumulation (GRAD_ACCUM batches per
    optimizer step), then evaluates on the validation segments. Segment
    probabilities are averaged per participant and a patient-level AUROC is
    computed. Whenever that AUROC improves, the model weights are written to
    ``checkpoints_t2/wav2vec2_fold{fold_num}.pt``.

    NOTE: the best epoch is chosen on the same validation fold whose AUROC is
    reported, so the returned score is an optimistic estimate.

    Args:
        fold_num: Fold number used in the checkpoint filename.
        train_segs, train_labs, train_pids: Parallel sequences of training
            waveforms, labels and participant ids.
        val_segs, val_labs, val_pids: Same, for the validation fold.

    Returns:
        ``(best_auroc, best_patient_logits)``: the best patient-level AUROC
        seen over the epochs, and a dict mapping participant id to the mean
        predicted probability at that epoch (sigmoid outputs, despite the
        name). The dict is empty if no epoch scored above 0.
    """
    train_ds = CoughDataset(train_segs, train_labs, train_pids)
    val_ds = CoughDataset(val_segs, val_labs, val_pids)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    n_batches = len(train_loader)

    model = CoughClassifier().to(DEVICE)
    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()), lr=LR, weight_decay=0.01
    )
    criterion = nn.BCEWithLogitsLoss()

    # Optimizer steps per epoch: one per full group of GRAD_ACCUM batches plus
    # one for a trailing partial group (ceil division).
    steps_per_epoch = (n_batches + GRAD_ACCUM - 1) // GRAD_ACCUM
    total_steps = steps_per_epoch * EPOCHS
    warmup_steps = int(total_steps * WARMUP_RATIO)
    # Linear warm-up from 10% of LR; afterwards the rate stays at LR.
    scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=0.1, total_iters=max(warmup_steps, 1)
    )

    best_auroc = 0
    best_patient_logits = {}

    for epoch in range(EPOCHS):
        model.train()
        train_loss = 0
        optimizer.zero_grad()

        for step, batch in enumerate(train_loader):
            audio = batch['audio'].to(DEVICE)
            labels = batch['label'].to(DEVICE)
            loss = criterion(model(audio), labels) / GRAD_ACCUM
            loss.backward()
            train_loss += loss.item() * GRAD_ACCUM

            # Step after every full accumulation group AND on the final batch.
            # Without the final-batch flush, gradients of the trailing
            # (n_batches % GRAD_ACCUM) batches would be wiped by the next
            # epoch's zero_grad() and never used. The trailing group keeps the
            # 1/GRAD_ACCUM scaling, so its update is proportionally smaller.
            is_last_batch = (step + 1) == n_batches
            if (step + 1) % GRAD_ACCUM == 0 or is_last_batch:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

        train_loss /= n_batches

        # ---- Validation: collect per-segment probabilities ----
        model.eval()
        seg_probs, seg_labels, seg_pids = [], [], []
        with torch.no_grad():
            for batch in val_loader:
                audio = batch['audio'].to(DEVICE)
                probs = torch.sigmoid(model(audio)).cpu().numpy()
                seg_probs.extend(probs)
                seg_labels.extend(batch['label'].numpy())
                seg_pids.extend(batch['pid'])

        # ---- Aggregate segments to patients (mean probability) ----
        pt_p, pt_l = {}, {}
        for pid, prob, lab in zip(seg_pids, seg_probs, seg_labels):
            pt_p.setdefault(pid, []).append(prob)
            pt_l[pid] = lab

        yt = np.array([pt_l[p] for p in pt_p])
        yp = np.array([np.mean(v) for v in pt_p.values()])
        # AUROC is undefined with a single class present; fall back to chance.
        auroc = roc_auc_score(yt, yp) if len(np.unique(yt)) > 1 else 0.5

        print(f'  Epoch {epoch+1}/{EPOCHS} — loss: {train_loss:.4f}, AUROC: {auroc:.4f}')

        if auroc > best_auroc:
            best_auroc = auroc
            best_patient_logits = {pid: np.mean(v) for pid, v in pt_p.items()}
            os.makedirs('checkpoints_t2', exist_ok=True)
            torch.save(model.state_dict(), f'checkpoints_t2/wav2vec2_fold{fold_num}.pt')

    # Release accelerator memory before the next fold builds a new model.
    del model, optimizer
    if DEVICE.type == 'mps':
        torch.mps.empty_cache()
    elif DEVICE.type == 'cuda':
        torch.cuda.empty_cache()

    return best_auroc, best_patient_logits

print('Training function ready.')

---
## Cell 7: Run 10-Fold CV

In [None]:
# Patient-level view: one label per participant, taken from that
# participant's first segment.
unique_pids = np.unique(all_pids)
pid_labels = np.array([all_labels[all_pids == pid][0] for pid in unique_pids])

# Cap the fold count at 80% of the smaller class size.
smallest_class = min(Counter(pid_labels).values())
n_folds_actual = min(N_FOLDS, int(smallest_class * 0.8))
skf = StratifiedKFold(n_splits=n_folds_actual, shuffle=True, random_state=42)

fold_aurocs = []
all_patient_logits, all_patient_labels = {}, {}

print('=== TRIAL 2: Fixed Segmentation ===')
print(f'Segments: {len(all_segments)} (Trial 1: 1818)')
print(f'Patients: {len(unique_pids)}')
print(f'Folds: {n_folds_actual}')
print(f'Device: {DEVICE}')
print('\nStarting training...\n')

# Folds are split over participants, so all segments of one participant land
# on the same side of the split.
for fold, (train_idx, val_idx) in enumerate(skf.split(unique_pids, pid_labels)):
    fold_no = fold + 1
    print(f'=== Fold {fold_no}/{n_folds_actual} ===')

    train_pids_set = set(unique_pids[train_idx])
    val_pids_set = set(unique_pids[val_idx])

    tr_s, tr_l, tr_p = [], [], []
    va_s, va_l, va_p = [], [], []
    train_buckets, val_buckets = (tr_s, tr_l, tr_p), (va_s, va_l, va_p)

    # Route every segment to the training or validation buckets by its pid.
    for seg, lab, pid in zip(all_segments, all_labels, all_pids):
        buckets = train_buckets if pid in train_pids_set else val_buckets
        for bucket, value in zip(buckets, (seg, lab, pid)):
            bucket.append(value)

    print(f'  Train: {len(tr_s)} segments ({len(train_pids_set)} patients)')
    print(f'  Val:   {len(va_s)} segments ({len(val_pids_set)} patients)')

    auroc, patient_logits = train_one_fold(
        fold_no, tr_s, tr_l, tr_p, va_s, va_l, va_p
    )

    fold_aurocs.append(auroc)
    all_patient_logits.update(patient_logits)
    all_patient_labels.update(
        (val_pid, pid_labels[unique_pids == val_pid][0]) for val_pid in val_pids_set
    )

    print(f'  ✅ Fold {fold_no} best AUROC: {auroc:.4f}\n')

mean_auroc, std_auroc = np.mean(fold_aurocs), np.std(fold_aurocs)

print('=' * 50)
print('TRIAL 2 RESULT')
print(f'Mean AUROC: {mean_auroc:.4f} +/- {std_auroc:.4f}')
print(f'Per-fold: {[f"{a:.3f}" for a in fold_aurocs]}')
print('\n--- Comparison ---')
print('Trial 1: 0.7330 +/- 0.0565')
print(f'Trial 2: {mean_auroc:.4f} +/- {std_auroc:.4f}')
print('Zambia:  0.852')
print(f'Delta:   {mean_auroc - 0.7330:+.4f} vs Trial 1')

---
## Cell 8: Results — ROC + Threshold

In [None]:
# Pool the out-of-fold patient predictions in a fixed (sorted) participant
# order so labels and probabilities stay aligned.
pids_order = sorted(all_patient_logits.keys())
y_true = np.array([all_patient_labels[p] for p in pids_order])
y_prob = np.array([all_patient_logits[p] for p in pids_order])

fpr_t2, tpr_t2, _ = roc_curve(y_true, y_prob)
auroc_t2 = roc_auc_score(y_true, y_prob)

plt.figure(figsize=(8, 7))
plt.plot(fpr_t2, tpr_t2, 'r-', lw=2.5, label=f'Trial 2 ({auroc_t2:.3f})')
plt.plot([0,1],[0,1],'k--', alpha=0.2)
# Shade the target region (TPR >= 0.90, FPR <= 0.30) in DATA coordinates.
# axhspan's xmin/xmax are axes-fraction units, so with the default axis
# margins xmax=0.30 did not coincide with FPR = 0.30 and the shading was
# misaligned with the guide lines below.
plt.fill_between([0, 0.30], 0.90, 1.0, alpha=0.08, color='green', label='WHO TPP zone')
plt.axhline(0.90, color='r', ls=':', alpha=0.3)
plt.axvline(0.30, color='g', ls=':', alpha=0.3)
plt.xlabel('FPR (1 - Specificity)', fontsize=12)
plt.ylabel('TPR (Sensitivity)', fontsize=12)
plt.title('RespiraHub Trial 2 — Fixed Segmentation', fontsize=14)
plt.legend(fontsize=11); plt.grid(alpha=0.15)
plt.tight_layout()
plt.savefig('roc_trial2.png', dpi=150)
plt.show()

In [None]:
# Operating-point table: sensitivity, specificity, PPV and NPV of the pooled
# patient-level predictions at a grid of probability thresholds.
print(f'{"Thresh":>7} {"Sens":>7} {"Spec":>7} {"PPV":>7} {"NPV":>7}')
print('-' * 40)

best_t, best_j = 0.5, -1
who_met = False

# Thresholds 0.15, 0.20, ..., 0.80. Built with linspace + round instead of a
# float-step np.arange, whose element count and endpoint depend on rounding
# and whose values drift (e.g. 0.30000000000000004).
thresholds = np.round(np.linspace(0.15, 0.80, 14), 2)

for t in thresholds:
    pred = (y_prob >= t).astype(int)
    tp = np.sum((pred == 1) & (y_true == 1))
    tn = np.sum((pred == 0) & (y_true == 0))
    fp = np.sum((pred == 1) & (y_true == 0))
    fn = np.sum((pred == 0) & (y_true == 1))
    # Each ratio falls back to 0 when its denominator is empty.
    sens = tp/(tp+fn) if (tp+fn) else 0
    spec = tn/(tn+fp) if (tn+fp) else 0
    ppv = tp/(tp+fp) if (tp+fp) else 0
    npv = tn/(tn+fn) if (tn+fn) else 0

    # Target operating point: sensitivity >= 0.90 with specificity >= 0.70.
    meets_who = sens >= 0.90 and spec >= 0.70
    flag = ' ✅ WHO' if meets_who else ''
    if meets_who:
        who_met = True

    # Youden's J; the first threshold reaching the maximum is kept.
    j = sens + spec - 1
    if j > best_j:
        best_j, best_t = j, t

    print(f'{t:>7.2f} {sens:>7.3f} {spec:>7.3f} {ppv:>7.3f} {npv:>7.3f}{flag}')

print(f'\nBest Youden threshold: {best_t:.2f}')
if who_met:
    print('✅ WHO TPP achievable!')
else:
    print('⚠️ WHO TPP not met. Phase 2 (+ anamnesis/domain adapt) should help.')

---
## Cell 9: Adversarial Robustness

In [None]:
# Reload the checkpoint of the best-scoring fold (folds are numbered from 1).
best_fold = np.argmax(fold_aurocs) + 1
model = CoughClassifier().to(DEVICE)
model.load_state_dict(torch.load(f'checkpoints_t2/wav2vec2_fold{best_fold}.pt', map_location=DEVICE))
model.eval()
print(f'Loaded fold {best_fold} (AUROC={fold_aurocs[best_fold-1]:.4f})')

# Feed three cough-free inputs (20 copies each) and report the mean/std of the
# predicted TB probability.
with torch.no_grad():
    noise = torch.randn(20, SEGMENT_SAMPLES).to(DEVICE)
    p_noise = torch.sigmoid(model(noise)).cpu().numpy()
    print(f'White noise  -> P(TB+) = {p_noise.mean():.4f} +/- {p_noise.std():.4f}')

    silence = torch.zeros(20, SEGMENT_SAMPLES).to(DEVICE)
    p_silence = torch.sigmoid(model(silence)).cpu().numpy()
    print(f'Silence      -> P(TB+) = {p_silence.mean():.4f} +/- {p_silence.std():.4f}')

    # Time axis with exactly SEGMENT_SEC / SEGMENT_SAMPLES seconds per sample.
    # An endpoint-inclusive linspace(0, SEGMENT_SEC, SEGMENT_SAMPLES) stretches
    # the spacing slightly, and the literal 3.14159 truncates pi; together
    # they detune the tone from 50 Hz.
    t = torch.arange(SEGMENT_SAMPLES) * (SEGMENT_SEC / SEGMENT_SAMPLES)
    hum = (torch.sin(2 * np.pi * 50 * t) * 0.01).unsqueeze(0).repeat(20, 1).to(DEVICE)
    p_hum = torch.sigmoid(model(hum)).cpu().numpy()
    print(f'50Hz hum     -> P(TB+) = {p_hum.mean():.4f} +/- {p_hum.std():.4f}')

print('\nAll should be ~0.50. If >> 0.60, model overfitting to background.')
del model
# Free accelerator memory on either backend, as train_one_fold does.
if DEVICE.type == 'mps':
    torch.mps.empty_cache()
elif DEVICE.type == 'cuda':
    torch.cuda.empty_cache()

---
## Cell 10: Save Everything

In [None]:
# Per-patient pooled predictions, in the same order as the ROC inputs.
results_df = pd.DataFrame(
    dict(participant=pids_order, true_label=y_true, predicted_prob=y_prob)
)
results_df.to_csv('patient_predictions_trial2.csv', index=False)

# Cross-validation statistics, reused for the JSON summary and final report.
cv_mean = np.mean(fold_aurocs)
cv_std = np.std(fold_aurocs)

summary = {
    'trial': 2,
    'changes': ['per-file padding (not concatenate+split)', 'Hyfe filter >= 0.5 (was 0.8)'],
    'model': 'Wav2Vec2-base (cough-only)',
    'dataset': 'CODA TB solicited',
    'n_participants': len(pids_order),
    'n_segments': len(all_segments),
    'segments_per_patient_mean': round(float(np.mean(counts)), 1),
    'segment_sec': SEGMENT_SEC,
    'hyfe_threshold': 0.5,
    'n_folds': n_folds_actual,
    'auroc_mean': round(float(cv_mean), 4),
    'auroc_std': round(float(cv_std), 4),
    'auroc_per_fold': [round(float(score), 4) for score in fold_aurocs],
    'best_threshold': round(float(best_t), 2),
    'trial1_auroc': 0.7330,
    'zambia_auroc': 0.852,
    'device': str(DEVICE),
}
with open('training_summary_trial2.json', 'w') as fh:
    fh.write(json.dumps(summary, indent=2))

print('Saved:')
for artifact in (
    'patient_predictions_trial2.csv',
    'training_summary_trial2.json',
    'checkpoints_t2/wav2vec2_fold*.pt',
    'roc_trial2.png',
):
    print(f'  {artifact}')
print()

rule = '=' * 50
print(rule)
print('TRIAL 2 COMPLETE')
print(rule)
print('Trial 1: 0.7330 +/- 0.0565')
print(f'Trial 2: {cv_mean:.4f} +/- {cv_std:.4f}')
print(f'Delta:   {cv_mean - 0.7330:+.4f}')
print('Zambia:  0.852')
print('\nNext: if AUROC >= 0.80 -> validate on longitudinal data')
print('       if AUROC < 0.80 -> add audio augmentation (Trial 3)')