# RespiraHub — Trial 4: Push Wav2Vec2 to the Limit

**Context:** Three trials with different preprocessing variants all plateaued at 0.69–0.73 AUROC. DREAM Challenge winner: 0.743. Suspected bottleneck: the model backbone.

**Changes from Trial 1 (best mean AUROC 0.733):**
1. Unfreeze last 2 transformer layers (deeper fine-tuning)
2. Audio augmentation (Gaussian noise, time shift, amplitude scaling)
3. Epochs: 10 with early stopping (patience 3)
4. Segmentation: Trial 1 style (concatenate all coughs per participant + split into 3 s segments) — best mean AUROC so far

**Target:** Beat DREAM Challenge winner (0.743). This is the final Wav2Vec2 attempt before switching to HeAR.

---
## Cell 1: Setup

In [None]:
import os, json, random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchaudio
import matplotlib.pyplot as plt
from collections import Counter
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score, roc_curve
from torch.utils.data import Dataset, DataLoader
from transformers import Wav2Vec2Model
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Seed every RNG the notebook draws from (Python `random` drives the
# augmentations, torch drives init/noise) so runs are reproducible.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(42)


def _pick_device():
    """Return (device, banner) for the first available backend: MPS, then CUDA, then CPU."""
    if torch.backends.mps.is_available():
        return torch.device('mps'), 'Using Apple Silicon MPS'
    if torch.cuda.is_available():
        return torch.device('cuda'), 'Using CUDA'
    return torch.device('cpu'), 'Using CPU'


DEVICE, _device_banner = _pick_device()
print(_device_banner)

print(f'PyTorch {torch.__version__}, Device: {DEVICE}')

---
## Cell 2: Load Metadata

In [None]:
AUDIO_DIR = '/Users/aida/code/development/tb-datasets/data/solicited/'
CLINICAL_PATH = '/Users/aida/code/development/tb-datasets/data/metadata/CODA_TB_Clinical_Meta_Info.csv'
ADDITIONAL_PATH = '/Users/aida/code/development/tb-datasets/data/metadata/CODA_TB_additional_variables_train.csv'
SOLICITED_PATH = '/Users/aida/code/development/tb-datasets/data/metadata/CODA_TB_Solicited_Meta_Info.csv'

clinical = pd.read_csv(CLINICAL_PATH)
additional = pd.read_csv(ADDITIONAL_PATH)
solicited = pd.read_csv(SOLICITED_PATH)

# Participant-level table: clinical info plus the optional additional variables.
meta = clinical.merge(additional, on='participant', how='left')
meta['label'] = meta['tb_status'].astype(int)

print(f'Participants: {len(meta)}')
print(f'TB+: {meta["label"].sum()}, TB-: {(meta["label"] == 0).sum()}')
print(f'Prevalence: {meta["label"].mean():.1%}')

# One row per solicited cough recording, joined to its participant's label.
df = solicited.merge(meta, on='participant', how='left')

# The left join leaves `label` NaN for recordings whose participant has no
# clinical row. Those would become NaN training targets (NaN BCE loss) and
# would break the patient-level stratified split, so drop them explicitly
# and restore the integer dtype.
n_unlabeled = int(df['label'].isna().sum())
if n_unlabeled:
    print(f'Dropping {n_unlabeled} cough files with no clinical metadata')
    df = df.dropna(subset=['label']).astype({'label': int})

df['filepath'] = df['filename'].apply(lambda f: os.path.join(AUDIO_DIR, f))
df['file_exists'] = df['filepath'].apply(os.path.exists)
df = df[df['file_exists']].reset_index(drop=True)

print(f'Total cough files: {len(df)} from {df["participant"].nunique()} participants')

---
## Cell 3: Trial 1 Segmentation (Concatenate All + Split 3s)

Return to the Trial 1 approach, which gave the best mean AUROC so far (0.733).

In [None]:
TARGET_SR = 16000
SEGMENT_SEC = 3.0
SEGMENT_SAMPLES = int(TARGET_SR * SEGMENT_SEC)  # 48000
GAP_SAMPLES = int(TARGET_SR * 0.05)  # 50ms gap

def load_audio(filepath, target_sr=16000):
    """Load an audio file as a mono 1-D float tensor at ``target_sr`` Hz.

    Multi-channel audio is averaged to mono, and the signal is resampled when
    the file's native rate differs from ``target_sr``.

    Returns:
        1-D tensor of samples, or None when the file cannot be loaded or
        processed. This is deliberately best-effort: callers skip files that
        return None.
    """
    try:
        w, sr = torchaudio.load(filepath)
        if w.shape[0] > 1:
            w = w.mean(dim=0, keepdim=True)
        if sr != target_sr:
            w = torchaudio.transforms.Resample(sr, target_sr)(w)
        return w.squeeze(0)
    except Exception as err:
        # Narrowed from a bare `except:` so Ctrl-C / SystemExit still propagate.
        # Report the failure because the notebook filters out all warnings.
        print(f'load_audio: skipping {filepath} ({type(err).__name__}: {err})')
        return None

def _concat_with_gaps(waveforms, gap_samples):
    """Join 1-D waveforms end to end, inserting ``gap_samples`` zeros between neighbours."""
    gap = torch.zeros(gap_samples)
    parts = []
    for i, w in enumerate(waveforms):
        if i:
            parts.append(gap)
        parts.append(w)
    return torch.cat(parts)


def _split_fixed(combined, segment_samples):
    """Cut ``combined`` into consecutive ``segment_samples``-long windows.

    Only whole windows are kept; the trailing remainder is dropped, as in
    Trial 1. A signal shorter than one window yields a single window,
    zero-padded evenly on both sides.
    """
    n_segments = max(1, len(combined) // segment_samples)
    segments = []
    for i in range(n_segments):
        seg = combined[i * segment_samples:(i + 1) * segment_samples]
        if len(seg) < segment_samples:
            # center-pad
            pad_total = segment_samples - len(seg)
            pad_left = pad_total // 2
            seg = torch.nn.functional.pad(seg, (pad_left, pad_total - pad_left))
        segments.append(seg)
    return segments


# Group by participant, concatenate all coughs, split into 3s segments
grouped = df.groupby('participant').agg(
    label=('label', 'first'),
    filepaths=('filepath', list),
).reset_index()

all_segments = []
all_labels = []
all_pids = []
skipped = 0

for _, row in tqdm(grouped.iterrows(), total=len(grouped), desc='Building segments'):
    pid = row['participant']
    label = row['label']

    # load all waveforms, ignoring unreadable or empty files
    waveforms = []
    for fp in row['filepaths']:
        w = load_audio(fp, TARGET_SR)
        if w is not None and len(w) > 0:
            waveforms.append(w)

    if not waveforms:
        skipped += 1
        continue

    # concatenate with 50ms gaps, then split into 3s segments
    segs = _split_fixed(_concat_with_gaps(waveforms, GAP_SAMPLES), SEGMENT_SAMPLES)
    all_segments.extend(segs)
    all_labels.extend([label] * len(segs))
    all_pids.extend([pid] * len(segs))

if not all_segments:
    # Fail with a clear message instead of a cryptic min() error below.
    raise RuntimeError(
        f'No audio could be loaded for any participant (skipped {skipped}); '
        f'check AUDIO_DIR and the audio files.'
    )

all_labels = np.array(all_labels)
all_pids = np.array(all_pids)

pid_counts = Counter(all_pids)
counts = list(pid_counts.values())

print(f'\nTotal segments: {len(all_segments)}')
print(f'Skipped: {skipped}')
print(f'Unique participants: {len(np.unique(all_pids))}')
print(f'Segments per patient: min={min(counts)}, max={max(counts)}, mean={np.mean(counts):.1f}')
print(f'\nUsing Trial 1 segmentation (concatenate all + split 3s)')

---
## Cell 4: Audio Augmentation

NEW in Trial 4: augment during training to reduce overfitting.

In [None]:
class AudioAugmenter:
    """Random waveform augmentations for cough segments.

    Three transforms are applied in sequence by ``__call__``: additive
    Gaussian noise, circular time shift, and amplitude scaling. Each one
    fires independently with probability ``p``. Every transform returns a
    new tensor, so the caller's segment is never modified in place.

    Args:
        p: probability of applying each augmentation.
        snr_range: (low, high) SNR in dB, drawn uniformly when ``add_noise``
            is called without an explicit ``snr_db``.
        max_shift: default maximum shift for ``time_shift``, as a fraction
            of the waveform length.
        scale_range: (low, high) range for the random amplitude multiplier.
    """
    def __init__(self, p=0.5, snr_range=(15, 30), max_shift=0.2, scale_range=(0.7, 1.3)):
        self.p = p  # probability of each augmentation
        self.snr_range = snr_range
        self.max_shift = max_shift
        self.scale_range = scale_range

    def _skip(self):
        """Return True when this augmentation should not fire (probability 1 - p)."""
        return random.random() > self.p

    def add_noise(self, waveform, snr_db=None):
        """Add white Gaussian noise at ``snr_db`` dB SNR relative to the waveform's mean power.

        When ``snr_db`` is None (the default, as used by ``__call__``), the SNR
        is drawn uniformly from ``self.snr_range``. A silent waveform gets zero noise.
        """
        if self._skip():
            return waveform
        if snr_db is None:
            snr_db = random.uniform(*self.snr_range)
        signal_power = waveform.pow(2).mean()
        noise_power = signal_power / (10 ** (snr_db / 10))
        noise = torch.randn_like(waveform) * noise_power.sqrt()
        return waveform + noise

    def time_shift(self, waveform, max_shift=None):
        """Circularly roll the audio by up to ``max_shift`` of its length, in either direction.

        ``max_shift`` defaults to ``self.max_shift``. Samples shifted off one
        end wrap around to the other (``torch.roll``).
        """
        if self._skip():
            return waveform
        if max_shift is None:
            max_shift = self.max_shift
        shift = int(len(waveform) * random.uniform(-max_shift, max_shift))
        return torch.roll(waveform, shift)

    def amplitude_scale(self, waveform):
        """Multiply the waveform by a factor drawn uniformly from ``self.scale_range``."""
        if self._skip():
            return waveform
        scale = random.uniform(*self.scale_range)
        return waveform * scale

    def __call__(self, waveform):
        waveform = self.add_noise(waveform)
        waveform = self.time_shift(waveform)
        waveform = self.amplitude_scale(waveform)
        return waveform

augmenter = AudioAugmenter(p=0.5)

# Sanity check: run one cloned segment through the augmenter and compare
# basic statistics before/after (the stored segment itself is untouched).
test_seg = all_segments[0].clone()
aug_seg = augmenter(test_seg)
for tag, tensor in (('Original', test_seg), ('Augmented', aug_seg)):
    print(f'{tag}: mean={tensor.mean():.6f}, std={tensor.std():.6f}')
print('Augmentation ready. Each aug applied with p=0.5.')

---
## Cell 5: Model (Unfreeze Last 2 Transformer Layers) + Dataset

In [None]:
class CoughClassifierV2(nn.Module):
    """Wav2Vec2 backbone with only its last ``n_unfrozen`` transformer layers trainable.

    Frame features from the backbone are mean-pooled over time and fed to a
    small MLP head that outputs one logit per example.

    Args:
        model_name: checkpoint passed to ``Wav2Vec2Model.from_pretrained``.
        dropout: dropout probability, used twice in the head.
        n_unfrozen: number of final encoder layers to fine-tune (default 2,
            clamped to the number of layers available).

    ``forward`` takes a batch of raw waveforms, presumably (batch, samples)
    as in the sanity check below, and returns logits of shape (batch,).
    NOTE(review): waveforms are fed un-normalized; the usual Wav2Vec2 feature
    extractor zero-mean/unit-variance normalizes inputs — consider matching it.
    """
    def __init__(self, model_name='facebook/wav2vec2-base', dropout=0.3, n_unfrozen=2):
        super().__init__()
        self.wav2vec2 = Wav2Vec2Model.from_pretrained(model_name)

        # Freeze everything first
        for param in self.wav2vec2.parameters():
            param.requires_grad = False

        # Unfreeze the last `n_unfrozen` transformer encoder layers
        layers = self.wav2vec2.encoder.layers
        n_unfrozen = max(0, min(n_unfrozen, len(layers)))
        for layer in layers[len(layers) - n_unfrozen:]:
            for param in layer.parameters():
                param.requires_grad = True

        # `encoder.layer_norm` sits AFTER the layer stack only in the
        # stable-layer-norm variant (config.do_stable_layer_norm=True). In
        # wav2vec2-base it normalizes the input to layer 0, so making it
        # trainable would force back-propagation through every frozen layer
        # (slower, far more memory) without adapting the top of the network.
        if getattr(self.wav2vec2.config, 'do_stable_layer_norm', False) and hasattr(
            self.wav2vec2.encoder, 'layer_norm'
        ):
            for param in self.wav2vec2.encoder.layer_norm.parameters():
                param.requires_grad = True

        hidden = self.wav2vec2.config.hidden_size  # 768 for wav2vec2-base
        self.head = nn.Sequential(
            nn.LayerNorm(hidden),
            nn.Dropout(dropout),
            nn.Linear(hidden, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, 1),
        )

    def forward(self, x):
        out = self.wav2vec2(x).last_hidden_state
        pooled = out.mean(dim=1)  # average over time frames
        return self.head(pooled).squeeze(-1)

class CoughDatasetV2(Dataset):
    """Segment-level dataset yielding ``{'audio', 'label', 'pid'}`` dicts.

    Labels are stored once as a float32 tensor (the dtype BCEWithLogitsLoss
    expects). When an augmenter is given, it is re-applied to the segment on
    every access, so each epoch sees a fresh random augmentation.
    """
    def __init__(self, segments, labels, pids, augmenter=None):
        self.segments = segments
        self.labels = torch.tensor(labels, dtype=torch.float32)
        self.pids = pids
        self.augmenter = augmenter

    def __len__(self):
        return len(self.segments)

    def __getitem__(self, idx):
        raw = self.segments[idx]
        transform = self.augmenter
        return {
            'audio': raw if transform is None else transform(raw),
            'label': self.labels[idx],
            'pid': self.pids[idx],
        }

# === Config ===
BATCH_SIZE = 4
LR = 1e-5  # lower LR for unfrozen layers
EPOCHS = 10
PATIENCE = 3  # early stopping
GRAD_ACCUM = 8
WARMUP_RATIO = 0.1
N_FOLDS = 10

# Echo the configuration (one line per setting).
_config_lines = [
    '=== Trial 4 Config ===',
    f'Batch: {BATCH_SIZE} (effective: {BATCH_SIZE * GRAD_ACCUM})',
    f'Epochs: {EPOCHS} (early stop patience: {PATIENCE})',
    f'LR: {LR} (lower for unfrozen transformer layers)',
    f'Folds: {N_FOLDS}',
    'Augmentation: noise + time shift + amplitude scale (p=0.5 each)',
    'Classifier head: LayerNorm -> Dropout -> 768->128 GELU -> Dropout -> 128->1',
]
print('\n'.join(_config_lines))

# Model check: build once, count parameters, and run a dummy forward pass.
print('\nLoading Wav2Vec2-base (unfreeze last 2 layers)...')
m = CoughClassifierV2()
_param_info = [(p.numel(), p.requires_grad) for p in m.parameters()]
n_total = sum(size for size, _ in _param_info)
n_train = sum(size for size, trainable_flag in _param_info if trainable_flag)
n_frozen = n_total - n_train
print(f'Total params: {n_total/1e6:.1f}M')
print(f'Trainable: {n_train/1e6:.1f}M ({n_train/n_total*100:.1f}%)')
print(f'Frozen: {n_frozen/1e6:.1f}M ({n_frozen/n_total*100:.1f}%)')

# Per-layer freeze status for the last 3 encoder layers.
# Counts are parameter tensors, not scalar parameters.
_encoder_layers = m.wav2vec2.encoder.layers
n_layers = len(_encoder_layers)
print(f'\nTransformer layers: {n_layers} total, last 2 unfrozen')
for i in range(max(0, n_layers - 3), n_layers):  # show last 3
    layer = _encoder_layers[i]
    _flags = [p.requires_grad for p in layer.parameters()]
    trainable, total = sum(_flags), len(_flags)
    status = 'UNFROZEN' if trainable > 0 else 'frozen'
    print(f'  Layer {i}: {status} ({trainable}/{total} params trainable)')

dummy = torch.randn(2, SEGMENT_SAMPLES)
with torch.no_grad():
    out = m(dummy)
print(f'\nForward pass OK: {out.shape}')
del m

---
## Cell 6: Training Function (with Early Stopping)

In [None]:
def _make_warmup_cosine(optimizer, total_steps, warmup_steps, eta_min=1e-7):
    """Per-optimizer-step LR schedule: linear warmup, then half-cosine decay to ``eta_min``.

    Each param group ramps up to its own initial LR over ``warmup_steps``
    steps, then decays to the absolute floor ``eta_min`` at ``total_steps``
    and stays there.
    """
    total_steps = max(1, int(total_steps))
    warmup_steps = max(0, min(int(warmup_steps), total_steps - 1))

    def _lambda_for(base_lr):
        # LambdaLR multiplies base_lr, so express the absolute floor as a ratio.
        floor = min(1.0, eta_min / base_lr) if base_lr > 0 else 0.0

        def lr_lambda(step):
            if step < warmup_steps:
                return (step + 1) / (warmup_steps + 1)
            progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
            return float(floor + (1.0 - floor) * 0.5 * (1.0 + np.cos(np.pi * progress)))

        return lr_lambda

    return torch.optim.lr_scheduler.LambdaLR(
        optimizer, [_lambda_for(group['lr']) for group in optimizer.param_groups]
    )


def _evaluate_patients(model, loader):
    """Score every segment in ``loader`` and soft-vote per patient.

    Returns:
        (auroc, patient_probs): ``patient_probs`` maps pid -> mean sigmoid
        probability over that patient's segments. The AUROC is computed on
        those patient-level scores, and is 0.5 when only one class is present.
    """
    model.eval()
    probs_by_pid, label_by_pid = {}, {}
    with torch.no_grad():
        for batch in loader:
            probs = torch.sigmoid(model(batch['audio'].to(DEVICE))).cpu().numpy()
            for pid, prob, lab in zip(batch['pid'], probs, batch['label'].numpy()):
                probs_by_pid.setdefault(pid, []).append(prob)
                label_by_pid[pid] = lab

    patient_probs = {pid: np.mean(v) for pid, v in probs_by_pid.items()}
    yt = np.array([label_by_pid[p] for p in patient_probs])
    yp = np.array(list(patient_probs.values()))
    auroc = roc_auc_score(yt, yp) if len(np.unique(yt)) > 1 else 0.5
    return auroc, patient_probs


def train_one_fold(fold_num, train_segs, train_labs, train_pids,
                   val_segs, val_labs, val_pids):
    """Train a fresh CoughClassifierV2 on one CV fold, with early stopping.

    The model state of the best epoch (by patient-level validation AUROC) is
    saved to ``checkpoints_t4/wav2vec2_fold{fold_num}.pt``.

    Returns:
        (best_auroc, best_patient_probs): the best validation AUROC, and the
        pid -> mean probability map from that epoch.

    NOTE: the epoch/checkpoint is selected on the same validation fold whose
    AUROC is reported, so the returned AUROC is optimistically biased. An
    inner validation split would be needed for an unbiased estimate.
    """
    train_ds = CoughDatasetV2(train_segs, train_labs, train_pids, augmenter=augmenter)
    val_ds = CoughDatasetV2(val_segs, val_labs, val_pids, augmenter=None)  # no aug for val
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    model = CoughClassifierV2().to(DEVICE)

    # Differential LR: lower for transformer, higher for head
    transformer_params, head_params = [], []
    for name, param in model.named_parameters():
        if param.requires_grad:
            (head_params if name.startswith('head.') else transformer_params).append(param)

    optimizer = torch.optim.AdamW([
        {'params': transformer_params, 'lr': LR},        # 1e-5 for transformer
        {'params': head_params, 'lr': LR * 10},           # 1e-4 for head
    ], weight_decay=0.01)

    criterion = nn.BCEWithLogitsLoss()

    # One scheduler step per optimizer step. The last partial accumulation
    # group of each epoch also steps, hence the ceiling division.
    n_batches = len(train_loader)
    steps_per_epoch = -(-n_batches // GRAD_ACCUM)
    total_steps = steps_per_epoch * EPOCHS
    warmup_steps = int(total_steps * WARMUP_RATIO)
    scheduler = _make_warmup_cosine(optimizer, total_steps, warmup_steps, eta_min=1e-7)

    best_auroc = -np.inf  # guarantees the first epoch is always checkpointed
    best_patient_logits = {}  # pid -> mean probability (name kept for compatibility)
    patience_counter = 0
    os.makedirs('checkpoints_t4', exist_ok=True)
    ckpt_path = f'checkpoints_t4/wav2vec2_fold{fold_num}.pt'

    for epoch in range(EPOCHS):
        model.train()
        train_loss = 0.0
        optimizer.zero_grad()

        for step, batch in enumerate(train_loader):
            audio = batch['audio'].to(DEVICE)
            labels = batch['label'].to(DEVICE)
            loss = criterion(model(audio), labels) / GRAD_ACCUM
            loss.backward()
            train_loss += loss.item() * GRAD_ACCUM

            # Step every GRAD_ACCUM batches AND on the final batch, so the
            # tail of the epoch isn't silently discarded by the next zero_grad().
            if (step + 1) % GRAD_ACCUM == 0 or step + 1 == n_batches:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step(); scheduler.step(); optimizer.zero_grad()

        train_loss /= max(1, n_batches)

        # Validation: soft voting per patient
        auroc, patient_probs = _evaluate_patients(model, val_loader)

        improved = ''
        if auroc > best_auroc:
            best_auroc = auroc
            best_patient_logits = patient_probs
            torch.save(model.state_dict(), ckpt_path)
            patience_counter = 0
            improved = ' *'
        else:
            patience_counter += 1

        print(f'  Epoch {epoch+1}/{EPOCHS} — loss: {train_loss:.4f}, AUROC: {auroc:.4f}{improved}')

        if patience_counter >= PATIENCE:
            print(f'  Early stopping at epoch {epoch+1} (no improvement for {PATIENCE} epochs)')
            break

    del model, optimizer, scheduler
    if DEVICE.type == 'mps': torch.mps.empty_cache()
    elif DEVICE.type == 'cuda': torch.cuda.empty_cache()

    return best_auroc, best_patient_logits

print('Training function ready.')
print('Features: differential LR, warmup + cosine scheduler, early stopping, augmentation')

---
## Cell 7: Run 10-Fold CV

In [None]:
unique_pids = np.unique(all_pids)

# Participant -> label in a single pass. The first segment's label is taken,
# matching the previous per-participant boolean-mask lookup.
label_of_pid = {}
for _pid, _lab in zip(all_pids, all_labels):
    label_of_pid.setdefault(_pid, _lab)
pid_labels = np.array([label_of_pid[pid] for pid in unique_pids])

n_folds_actual = min(N_FOLDS, int(min(Counter(pid_labels).values()) * 0.8))
if n_folds_actual < 2:
    raise ValueError(
        f'Too few participants in the minority class for stratified CV '
        f'(computed n_folds={n_folds_actual}; need at least 2).'
    )
skf = StratifiedKFold(n_splits=n_folds_actual, shuffle=True, random_state=42)

fold_aurocs = []
all_patient_logits = {}
all_patient_labels = {}

print(f'=== TRIAL 4: Push Wav2Vec2 to the Limit ===')
print(f'Changes: unfreeze 2 layers, augmentation, 10 epochs, early stop, differential LR')
print(f'Segmentation: Trial 1 (concatenate all + split 3s)')
print(f'Segments: {len(all_segments)}')
print(f'Patients: {len(unique_pids)}')
print(f'Folds: {n_folds_actual}, Device: {DEVICE}')
print(f'\n--- Benchmarks ---')
print(f'DREAM Challenge winner: 0.743')
print(f'Trial 1 (baseline): 0.733')
# NOTE(review): the saved summary (Cell 10) records zambia_auroc=0.852; confirm which value is correct.
print(f'Zambia (TB+ vs OR): 0.801')
print(f'\nStarting training...\n')

for fold, (train_idx, val_idx) in enumerate(skf.split(unique_pids, pid_labels)):
    print(f'=== Fold {fold+1}/{n_folds_actual} ===')

    train_pids_set = set(unique_pids[train_idx])
    val_pids_set = set(unique_pids[val_idx])

    tr_s, tr_l, tr_p = [], [], []
    va_s, va_l, va_p = [], [], []

    # Patient-level split: every segment of a participant goes to the same side.
    for seg, lab, pid in zip(all_segments, all_labels, all_pids):
        if pid in train_pids_set:
            tr_s.append(seg); tr_l.append(lab); tr_p.append(pid)
        else:
            va_s.append(seg); va_l.append(lab); va_p.append(pid)

    print(f'  Train: {len(tr_s)} segments ({len(train_pids_set)} patients)')
    print(f'  Val:   {len(va_s)} segments ({len(val_pids_set)} patients)')

    auroc, patient_logits = train_one_fold(
        fold+1, tr_s, tr_l, tr_p, va_s, va_l, va_p
    )

    fold_aurocs.append(auroc)
    all_patient_logits.update(patient_logits)
    for pid in val_pids_set:
        all_patient_labels[pid] = label_of_pid[pid]

    print(f'  \u2705 Fold {fold+1} best AUROC: {auroc:.4f}\n')

print('=' * 50)
print(f'TRIAL 4 RESULT')
print(f'Mean AUROC: {np.mean(fold_aurocs):.4f} +/- {np.std(fold_aurocs):.4f}')
print(f'Per-fold: {[f"{a:.3f}" for a in fold_aurocs]}')
print(f'\n--- Full Comparison ---')
print(f'DREAM winner: 0.743')
print(f'Trial 1: 0.7330 +/- 0.0565')
print(f'Trial 2: 0.6926 +/- 0.0438')
print(f'Trial 3: 0.7247 +/- 0.0724')
print(f'Trial 4: {np.mean(fold_aurocs):.4f} +/- {np.std(fold_aurocs):.4f}')
beat_dream = np.mean(fold_aurocs) > 0.743
# Build the verdict outside the f-string: backslash escapes inside an
# f-string expression are a SyntaxError before Python 3.12.
verdict = ('\u2705 BEAT DREAM CHALLENGE WINNER!' if beat_dream
           else '\u26a0\ufe0f Did not beat DREAM winner. Switch to HeAR (Trial 5).')
print(f'\n{verdict}')

---
## Cell 8: Results — ROC + Threshold

In [None]:
# Pool out-of-fold patient probabilities from all folds into one ROC curve.
# NOTE(review): each fold's model is calibrated differently, so this pooled
# AUROC can differ from the mean of the per-fold AUROCs.
pids_order = sorted(all_patient_logits.keys())
y_true = np.array([all_patient_labels[p] for p in pids_order])
y_prob = np.array([all_patient_logits[p] for p in pids_order])

fpr_t4, tpr_t4, _ = roc_curve(y_true, y_prob)
auroc_t4 = roc_auc_score(y_true, y_prob)

plt.figure(figsize=(8, 7))
plt.plot(fpr_t4, tpr_t4, 'r-', lw=2.5, label=f'Trial 4 ({auroc_t4:.3f})')
plt.plot([0,1],[0,1],'k--', alpha=0.2)
# WHO TPP target region (sensitivity >= 0.90, specificity >= 0.70, i.e.
# FPR <= 0.30), drawn in DATA coordinates. axhspan's xmin/xmax are axes
# fractions, which do not line up with FPR = 0.30 under default margins.
plt.fill_between([0, 0.30], 0.90, 1.0, alpha=0.08, color='green', label='WHO TPP zone')
plt.axhline(0.90, color='r', ls=':', alpha=0.3)
plt.axvline(0.30, color='g', ls=':', alpha=0.3)
plt.xlabel('FPR (1 - Specificity)', fontsize=12)
plt.ylabel('TPR (Sensitivity)', fontsize=12)
plt.title('RespiraHub Trial 4 — Wav2Vec2 Optimized', fontsize=14)
plt.legend(fontsize=11); plt.grid(alpha=0.15)
plt.tight_layout()
plt.savefig('roc_trial4.png', dpi=150)
plt.show()

In [None]:
# Sensitivity/specificity/PPV/NPV at a grid of probability thresholds. The
# table flags rows meeting the WHO TPP target (sens >= 0.90, spec >= 0.70)
# and tracks the threshold maximising Youden's J = sens + spec - 1.
print(f'{"Thresh":>7} {"Sens":>7} {"Spec":>7} {"PPV":>7} {"NPV":>7}')
print('-' * 40)

best_t, best_j = 0.5, -1
who_met = False

# Round the grid so thresholds are exact two-decimal values
# (np.arange accumulates float error, e.g. 0.30000000000000004).
for t in np.round(np.arange(0.15, 0.85, 0.05), 2):
    pred = (y_prob >= t).astype(int)
    tp = np.sum((pred == 1) & (y_true == 1))
    tn = np.sum((pred == 0) & (y_true == 0))
    fp = np.sum((pred == 1) & (y_true == 0))
    fn = np.sum((pred == 0) & (y_true == 1))
    sens = tp/(tp+fn) if (tp+fn) else 0
    spec = tn/(tn+fp) if (tn+fp) else 0
    ppv = tp/(tp+fp) if (tp+fp) else 0
    npv = tn/(tn+fn) if (tn+fn) else 0

    meets_who = sens >= 0.90 and spec >= 0.70
    who_met = who_met or meets_who
    flag = ' \u2705 WHO' if meets_who else ''

    j = sens + spec - 1
    if j > best_j: best_j, best_t = j, t

    print(f'{t:>7.2f} {sens:>7.3f} {spec:>7.3f} {ppv:>7.3f} {npv:>7.3f}{flag}')

print(f'\nBest Youden threshold: {best_t:.2f}')
# Previously computed but never reported.
print(f'WHO TPP (sens>=0.90, spec>=0.70) met at any threshold: {who_met}')

---
## Cell 9: Adversarial Robustness

In [None]:
# Reload the best fold's checkpoint and probe it with non-cough inputs
# (white noise, silence, 50 Hz mains hum) to see what P(TB+) it assigns.
best_fold = int(np.argmax(fold_aurocs)) + 1  # checkpoints are saved 1-indexed
model = CoughClassifierV2().to(DEVICE)
# The checkpoint is a plain state_dict, so refuse to unpickle arbitrary objects.
state = torch.load(f'checkpoints_t4/wav2vec2_fold{best_fold}.pt',
                   map_location=DEVICE, weights_only=True)
model.load_state_dict(state)
model.eval()
print(f'Loaded fold {best_fold} (AUROC={fold_aurocs[best_fold-1]:.4f})')

with torch.no_grad():
    noise = torch.randn(20, SEGMENT_SAMPLES).to(DEVICE)
    p_noise = torch.sigmoid(model(noise)).cpu().numpy()
    print(f'White noise  -> P(TB+) = {p_noise.mean():.4f} +/- {p_noise.std():.4f}')

    silence = torch.zeros(20, SEGMENT_SAMPLES).to(DEVICE)
    p_silence = torch.sigmoid(model(silence)).cpu().numpy()
    print(f'Silence      -> P(TB+) = {p_silence.mean():.4f} +/- {p_silence.std():.4f}')

    t_ax = torch.linspace(0, SEGMENT_SEC, SEGMENT_SAMPLES)
    hum = (torch.sin(2 * np.pi * 50 * t_ax) * 0.01).unsqueeze(0).repeat(20, 1).to(DEVICE)
    p_hum = torch.sigmoid(model(hum)).cpu().numpy()
    print(f'50Hz hum     -> P(TB+) = {p_hum.mean():.4f} +/- {p_hum.std():.4f}')

print('\nAll should be ~0.50. If >> 0.60, model overfitting to background.')
del model, state
if DEVICE.type == 'mps': torch.mps.empty_cache()
elif DEVICE.type == 'cuda': torch.cuda.empty_cache()

---
## Cell 10: Save Everything

In [None]:
# Per-patient out-of-fold predictions.
results_df = pd.DataFrame(
    {'participant': pids_order, 'true_label': y_true, 'predicted_prob': y_prob}
)
results_df.to_csv('patient_predictions_trial4.csv', index=False)

_mean_auroc = float(np.mean(fold_aurocs))
_std_auroc = float(np.std(fold_aurocs))
_changes = [
    'unfreeze last 2 transformer layers',
    'audio augmentation (noise, time shift, amplitude)',
    '10 epochs with early stopping (patience 3)',
    'differential LR (1e-5 transformer, 1e-4 head)',
    'cosine annealing scheduler',
    'deeper classifier head (768->128->1 with GELU)',
    'Trial 1 segmentation (concatenate all + split 3s)',
]

# Run summary; key order is kept so the JSON file layout is unchanged.
summary = {
    'trial': 4,
    'changes': _changes,
    'model': 'Wav2Vec2-base (last 2 layers unfrozen)',
    'dataset': 'CODA TB solicited',
    'n_participants': len(pids_order),
    'n_segments': len(all_segments),
    'segments_per_patient_mean': round(float(np.mean(counts)), 1),
    'segment_sec': SEGMENT_SEC,
    'n_folds': n_folds_actual,
    'auroc_mean': round(_mean_auroc, 4),
    'auroc_std': round(_std_auroc, 4),
    'auroc_per_fold': [round(float(a), 4) for a in fold_aurocs],
    'best_threshold': round(float(best_t), 2),
    'dream_winner': 0.743,
    'trial1_auroc': 0.7330,
    'trial2_auroc': 0.6926,
    'trial3_auroc': 0.7247,
    # NOTE(review): Cell 7 prints the Zambia benchmark as 0.801; confirm which value is correct.
    'zambia_auroc': 0.852,
    'device': str(DEVICE),
}
with open('training_summary_trial4.json', 'w') as fh:
    json.dump(summary, fh, indent=2)

print('Saved:')
for _artifact in ('patient_predictions_trial4.csv',
                  'training_summary_trial4.json',
                  'checkpoints_t4/wav2vec2_fold*.pt',
                  'roc_trial4.png'):
    print(f'  {_artifact}')
print()

_rule = '=' * 50
print(_rule)
print('TRIAL 4 COMPLETE')
print(_rule)
for _line in ('DREAM winner: 0.743',
              'Trial 1: 0.7330 +/- 0.0565',
              'Trial 2: 0.6926 +/- 0.0438',
              'Trial 3: 0.7247 +/- 0.0724'):
    print(_line)
print(f'Trial 4: {np.mean(fold_aurocs):.4f} +/- {np.std(fold_aurocs):.4f}')
print('\nNext: Trial 5 (HeAR backbone) regardless of result.')
print('Trial 4 is the final Wav2Vec2 experiment.')