# Model
### Wav2Vec2

In [31]:
import pandas as pd

In [32]:
import os, json, time, re, random, numpy as np
from pathlib import Path
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

try:
    import torchaudio
    HAVE_TA = True
except Exception:
    HAVE_TA = False
    print("[WARN] torchaudio not found, augmentation disabled.")

import pyedflib
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import (
    accuracy_score, f1_score, precision_recall_fscore_support,
    classification_report, confusion_matrix, roc_auc_score,
    log_loss, average_precision_score
)

from transformers import (
    Wav2Vec2Processor, Wav2Vec2ForSequenceClassification,
    get_linear_schedule_with_warmup
)

In [33]:
# Reproducibility
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

# Prefer CUDA (your RTX 4060). Fallback to CPU if unavailable.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print("Using device:", device)

Using device: cuda


In [34]:
CSV_PATH = r"C:\V89\Snore_Apnea_Analyze\EDF_RML\data_csv\respiratory_plus_normal.csv"
EDF_ROOT = r"C:\V89\data2"

MODEL_NAME   = "facebook/wav2vec2-base"
MODEL_TAG    = "wav2vec2-base-osa-win4060"

SAMPLE_RATE  = 16000
MAX_SECONDS  = 8      # crop to avoid OOM
BATCH_SIZE   = 4
EPOCHS       = 6
LR           = 2e-5
WARMUP_RATIO = 0.1
FREEZE_BASE  = True        # freeze feature encoder early
UNFREEZE_EPOCH = 2         # unfreeze after this epoch (1-based)
AUGMENT      = True        # light noise/gain
USE_CLASS_WEIGHTS = False  # set True if severe class imbalance

SPLIT_FILE   = Path("./splits/split_indices.json")
OUT_DIR      = Path(f"./eval_out/{MODEL_TAG}")
OUT_DIR.mkdir(parents=True, exist_ok=True)

In [35]:
class SleepApneaDataset(Dataset):
    """
    Expect df with columns: patient_id, type, segment_index, segment_local_start_sec, duration_sec, ...
    Resolve EDF path once and drop rows without matches.
    """
    def __init__(self, df: pd.DataFrame, edf_root: str, sample_rate: int = 16000,
                 prefer_audio_channels=("sound","snore","tracheal","microphone","audio","throat")):
        self.sample_rate = sample_rate
        self.edf_root = Path(edf_root)
        self.prefer_audio_channels = tuple(s.lower() for s in prefer_audio_channels)

        df = df.copy()
        df["pid_str"]   = df["patient_id"].astype(str)
        df["pid_unpad"] = df["pid_str"].str.lstrip("0")
        df["pid_pad8"]  = df["pid_unpad"].str.zfill(8)
        df["seg3"]      = df["segment_index"].astype(int).map(lambda x: f"{x:03d}")

        def resolve_row(row):
            pid_unpad = row["pid_unpad"]
            pid_pad8  = row["pid_pad8"]
            seg3      = row["seg3"]
            patterns = [
                f"*{pid_pad8}*{seg3}*.edf",
                f"*{pid_unpad}*{seg3}*.edf",
                f"*{pid_pad8}*.edf",
                f"*{pid_unpad}*.edf",
            ]
            for pat in patterns:
                hits = list(self.edf_root.rglob(pat))
                if len(hits) == 1:
                    return hits[0]
                if len(hits) > 1:
                    ranked = sorted(
                        hits,
                        key=lambda p: (
                            0 if re.search(rf"{seg3}\b", p.stem) else 1,
                            0 if re.search(r"(snore|sound|trach|mic|psg|audio|throat)", p.stem.lower()) else 1,
                            len(p.as_posix())
                        )
                    )
                    return ranked[0]
            return None

        df["edf_path"] = df.apply(resolve_row, axis=1)
        missing = df["edf_path"].isna().sum()
        if missing:
            print(f"[WARNING] Skipping {missing} rows that have no matching EDF file.")
        df = df.dropna(subset=["edf_path"]).reset_index(drop=True)

        self.df = df
        self.labels = sorted(self.df['type'].unique().tolist())
        self.label_encoder = LabelEncoder().fit(self.labels)

    def __len__(self):
        return len(self.df)

    def _pick_channel_index(self, f: pyedflib.EdfReader):
        labels = [s.lower() for s in f.getSignalLabels()]
        for i, name in enumerate(labels):
            if any(key in name for key in self.prefer_audio_channels):
                return i
        return 0  # fallback

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        edf_path = Path(row["edf_path"])
        start_sec = float(row['segment_local_start_sec'])
        duration_sec = float(row['duration_sec'])
        label = self.label_encoder.transform([row['type']])[0]

        f = pyedflib.EdfReader(str(edf_path))
        ch_idx = self._pick_channel_index(f)
        signal = f.readSignal(ch_idx)
        fs = f.getSampleFrequency(ch_idx)
        f.close()

        start_sample = max(0, int(start_sec * fs))
        end_sample   = min(int((start_sec + duration_sec) * fs), len(signal))
        audio = signal[start_sample:end_sample]

        x = torch.tensor(audio, dtype=torch.float32)
        if fs != self.sample_rate:
            if not HAVE_TA:
                raise RuntimeError("torchaudio is required for resampling")
            x = torchaudio.transforms.Resample(fs, self.sample_rate)(x)

        return {"audio": x, "label": int(label)}

In [36]:
df = pd.read_csv(CSV_PATH)
dataset = SleepApneaDataset(df, edf_root=EDF_ROOT, sample_rate=SAMPLE_RATE)

print("Resolved EDF rows:", len(dataset))
print(dataset.df[["patient_id","pid_unpad","pid_pad8","segment_index","edf_path"]].head(8))

label_names = list(dataset.label_encoder.classes_)
n_classes = len(label_names)
print("Classes:", label_names, " | n_classes =", n_classes)

Resolved EDF rows: 1824
   patient_id pid_unpad  pid_pad8  segment_index  \
0         999       999  00000999              0   
1         999       999  00000999              0   
2         999       999  00000999              0   
3         999       999  00000999              0   
4         999       999  00000999              0   
5         999       999  00000999              0   
6         999       999  00000999              0   
7         999       999  00000999              0   

                                edf_path  
0  C:\V89\data2\00000999-100507[001].edf  
1  C:\V89\data2\00000999-100507[001].edf  
2  C:\V89\data2\00000999-100507[001].edf  
3  C:\V89\data2\00000999-100507[001].edf  
4  C:\V89\data2\00000999-100507[001].edf  
5  C:\V89\data2\00000999-100507[001].edf  
6  C:\V89\data2\00000999-100507[001].edf  
7  C:\V89\data2\00000999-100507[001].edf  
Classes: [np.str_('CentralApnea'), np.str_('Hypopnea'), np.str_('MixedApnea'), np.str_('Normal'), np.str_('ObstructiveAp

In [37]:
if SPLIT_FILE.exists():
    with open(SPLIT_FILE, "r") as f:
        idx = json.load(f)
    train_idx, val_idx, test_idx = idx["train"], idx["val"], idx["test"]
else:
    g = torch.Generator().manual_seed(SEED)
    N = len(dataset)
    perm = torch.randperm(N, generator=g).tolist()
    train_ratio, val_ratio = 0.8, 0.1
    n_train = int(train_ratio * N)
    n_val   = int(val_ratio   * N)
    train_idx = perm[:n_train]
    val_idx   = perm[n_train:n_train+n_val]
    test_idx  = perm[n_train+n_val:]
    SPLIT_FILE.parent.mkdir(parents=True, exist_ok=True)
    with open(SPLIT_FILE, "w") as f:
        json.dump({"train": train_idx, "val": val_idx, "test": test_idx}, f, indent=2)

train_ds = Subset(dataset, train_idx)
val_ds   = Subset(dataset, val_idx)
test_ds  = Subset(dataset, test_idx)

print("Split sizes:", len(train_ds), len(val_ds), len(test_ds))

Split sizes: 1459 182 183


In [38]:
class Wav2Vec2Collator:
    """Prepare raw waveforms for Wav2Vec2 (padding + attention mask + optional aug)."""
    def __init__(self, processor, sr=16000, max_seconds=8, augment=False):
        self.processor = processor
        self.sr = sr
        self.max_len = int(max_seconds * sr)
        self.augment = augment and HAVE_TA

    def _augment(self, x: torch.Tensor) -> torch.Tensor:
        if not self.augment: return x
        # Gaussian noise
        if random.random() < 0.5:
            noise_std = 0.005 * (x.abs().mean().item() + 1e-6)
            x = x + torch.randn_like(x) * noise_std
        # Random gain
        if random.random() < 0.3:
            gain_db = random.uniform(-3.0, 3.0)
            x = x * (10.0 ** (gain_db / 20.0))
        return x

    def __call__(self, batch):
        waves, labels = [], []
        for b in batch:
            w = b["audio"]
            if isinstance(w, np.ndarray):
                w = torch.from_numpy(w)
            w = w.float().view(-1)

            # Crop long clips
            if len(w) > self.max_len:
                start = random.randint(0, len(w) - self.max_len)
                w = w[start:start + self.max_len]

            # Peak normalize
            peak = w.abs().max()
            if peak > 0:
                w = w / peak

            # Light augmentation
            w = self._augment(w)

            waves.append(w.numpy())
            labels.append(int(b["label"]))

        inputs = processor(
            waves,
            sampling_rate=self.sr,
            return_tensors="pt",
            padding=True,
            truncation=False
        )
        inputs["labels"] = torch.tensor(labels, dtype=torch.long)
        return inputs

In [39]:
processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)

model = Wav2Vec2ForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=n_classes,
    ignore_mismatched_sizes=True,
    use_safetensors=True
)
# Label maps in config (clean checkpoints/inference)
model.config.id2label = {i: name for i, name in enumerate(label_names)}
model.config.label2id = {name: i for i, name in enumerate(label_names)}

if FREEZE_BASE:
    model.freeze_feature_encoder()

model.to(device)

Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'projector.bias', 'projector.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Wav2Vec2ForSequenceClassification(
  (wav2vec2): Wav2Vec2Model(
    (feature_extractor): Wav2Vec2FeatureEncoder(
      (conv_layers): ModuleList(
        (0): Wav2Vec2GroupNormConvLayer(
          (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
          (activation): GELUActivation()
          (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
        )
        (1-4): 4 x Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
        (5-6): 2 x Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
      )
    )
    (feature_projection): Wav2Vec2FeatureProjection(
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (projection): Linear(in_features=512, out_features=768, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)


In [40]:
collate_fn = Wav2Vec2Collator(processor, sr=SAMPLE_RATE, max_seconds=MAX_SECONDS, augment=AUGMENT)

# For notebooks on Windows, keep num_workers=0 (pyedflib may not be multiprocess-safe in all setups)
NUM_WORKERS = 0
PIN_MEMORY = (device.type == "cuda")

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          collate_fn=collate_fn, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False,
                          collate_fn=collate_fn, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False,
                          collate_fn=collate_fn, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)



In [41]:
def evaluate_model(model, loader, label_encoder, device, return_arrays=False):
    model.eval()
    y_true, y_pred, y_proba_chunks = [], [], []
    with torch.no_grad():
        for batch in loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            logits = model(
                input_values=batch["input_values"],
                attention_mask=batch.get("attention_mask", None)
            ).logits
            proba = torch.softmax(logits, dim=-1).cpu().numpy()
            y_proba_chunks.append(proba)
            y_pred.extend(np.argmax(proba, axis=1))
            y_true.extend(batch["labels"].cpu().numpy())

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    y_proba = np.vstack(y_proba_chunks) if len(y_proba_chunks) > 0 else np.zeros((0, len(label_encoder.classes_)))

    class_names = list(label_encoder.classes_)
    n_classes = len(class_names)
    
    # Fix: Use only labels that appear in predictions
    unique_labels = np.unique(np.concatenate([y_true, y_pred]))
    cm = confusion_matrix(y_true, y_pred, labels=unique_labels)

    acc = accuracy_score(y_true, y_pred) if len(y_true) else np.nan
    f1_macro = f1_score(y_true, y_pred, average="macro", zero_division=0) if len(y_true) else np.nan
    f1_weighted = f1_score(y_true, y_pred, average="weighted", zero_division=0) if len(y_true) else np.nan
    prec_macro, rec_macro, _, _ = precision_recall_fscore_support(y_true, y_pred, average="macro", zero_division=0) if len(y_true) else (np.nan, np.nan, None, None)

    cm_float = cm.astype(float) if cm.size else np.zeros((len(unique_labels), len(unique_labels)), float)
    row_sums = cm_float.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0] = 1.0
    bal_acc = (cm_float / row_sums).diagonal().mean() if cm.size else np.nan

    try:
        roc_auc_macro = roc_auc_score(y_true, y_proba, multi_class="ovr", average="macro")
    except Exception:
        roc_auc_macro = np.nan
    try:
        pr_auc_macro = average_precision_score(np.eye(n_classes)[y_true], y_proba, average="macro")
    except Exception:
        pr_auc_macro = np.nan
    try:
        ll = log_loss(y_true, y_proba, labels=list(range(n_classes)))
    except Exception:
        ll = np.nan

    # Fix: Use target_names that match actual labels in predictions
    report_dict = classification_report(
        y_true, y_pred, 
        labels=unique_labels,
        target_names=[class_names[i] for i in unique_labels],
        output_dict=True, zero_division=0
    ) if len(y_true) else {}

    metrics = {
        "accuracy": acc,
        "balanced_accuracy": bal_acc,
        "f1_macro": f1_macro,
        "f1_weighted": f1_weighted,
        "precision_macro": prec_macro,
        "recall_macro": rec_macro,
        "roc_auc_macro": roc_auc_macro,
        "pr_auc_macro": pr_auc_macro,
        "log_loss": ll,
        "support": int(len(y_true))
    }
    out = (metrics, report_dict, cm, class_names)
    if return_arrays:
        return out + (y_true, y_pred, y_proba)
    return out

def save_confusion_matrix(cm, class_names, save_path, normalize=True, title=None):
    cm_plot = cm.astype(float)
    if normalize and cm_plot.size:
        row_sums = cm_plot.sum(axis=1, keepdims=True)
        row_sums[row_sums == 0] = 1.0
        cm_plot = cm_plot / row_sums

    plt.figure(figsize=(6, 5))
    plt.imshow(cm_plot, interpolation='nearest')
    plt.title(title or "Confusion Matrix")
    plt.xlabel("Predicted"); plt.ylabel("True")
    ticks = np.arange(len(class_names))
    plt.xticks(ticks, class_names, rotation=45, ha="right"); plt.yticks(ticks, class_names)

    fmt = ".2f" if normalize else "d"
    thresh = cm_plot.max() / 2. if cm_plot.size else 0.5
    for i in range(cm_plot.shape[0]):
        for j in range(cm_plot.shape[1]):
            val = cm_plot[i, j] if normalize else int(cm[i, j])
            plt.text(j, i, format(val, fmt),
                     ha="center", va="center",
                     color="white" if (cm_plot[i, j] if cm_plot.size else 0) > thresh else "black")
    plt.tight_layout()
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(save_path, bbox_inches="tight"); plt.close()

def save_classification_report(report_dict, save_csv_path):
    if not report_dict:
        return None
    df = pd.DataFrame(report_dict).T
    Path(save_csv_path).parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(save_csv_path, index=True)
    return df

def append_metrics_row(results_csv, model_name, split_name, metrics, extras=None):
    row = {
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
        "model": model_name,
        "split": split_name,
        **metrics
    }
    if extras:
        row.update(extras)
    if os.path.exists(results_csv):
        df = pd.read_csv(results_csv)
        df = pd.concat([df, pd.DataFrame([row])], ignore_index=True)
    else:
        df = pd.DataFrame([row])
    Path(results_csv).parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(results_csv, index=False)
    return df.tail(1)

## Train Loop

In [42]:
logits = model(
    input_values=batch["input_values"],
    attention_mask=batch.get("attention_mask", None)
).logits

In [43]:
total_steps  = len(train_loader) * EPOCHS
warmup_steps = int(total_steps * WARMUP_RATIO)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, total_steps)

use_amp = (device.type == "cuda")
scaler  = torch.cuda.amp.GradScaler(enabled=use_amp)

# Optional: class weights
if USE_CLASS_WEIGHTS:
    all_y = dataset.label_encoder.transform(dataset.df['type'])
    from sklearn.utils.class_weight import compute_class_weight
    classes = np.arange(n_classes)
    weights = compute_class_weight(class_weight='balanced', classes=classes, y=all_y)
    class_weights = torch.tensor(weights, dtype=torch.float32, device=device)
    ce_loss = nn.CrossEntropyLoss(weight=class_weights)
else:
    ce_loss = nn.CrossEntropyLoss()

def run_eval_and_log(split_name, loader):
    metrics, report, cm, class_names = evaluate_model(model, loader, dataset.label_encoder, device)
    print(f"{split_name.upper()}:", metrics)
    save_confusion_matrix(cm, class_names, OUT_DIR / f"cm_{split_name}.png", normalize=True, title=f"{split_name.title()} CM (norm)")
    save_classification_report(report, OUT_DIR / f"cls_report_{split_name}.csv")
    append_metrics_row("./eval_out/scoreboard.csv", MODEL_TAG, split_name, metrics, extras={f"n_{split_name}": len(loader.dataset) if hasattr(loader, "dataset") else None})
    return metrics

for epoch in range(1, EPOCHS + 1):
    model.train()
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}")
    running = 0.0

    if FREEZE_BASE and UNFREEZE_EPOCH is not None and epoch == UNFREEZE_EPOCH:
        try:
            model.unfreeze_feature_encoder()
            print("[Info] Unfroze feature encoder.")
        except Exception:
            pass

    for batch in pbar:
        batch = {k: v.to(device) for k, v in batch.items()}
        optimizer.zero_grad(set_to_none=True)

        if use_amp:
            with torch.cuda.amp.autocast():
                outputs = model(
                    input_values=batch["input_values"],
                    attention_mask=batch.get("attention_mask", None)
                )
                logits = outputs.logits
                loss = ce_loss(logits, batch["labels"])
            scaler.scale(loss).backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(
                input_values=batch["input_values"],
                attention_mask=batch.get("attention_mask", None)
            )
            logits = outputs.logits
            loss = ce_loss(logits, batch["labels"])
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

        scheduler.step()
        running += loss.item()
        pbar.set_postfix(loss=running / max(1, pbar.n))

    _ = run_eval_and_log("val", val_loader)

print("Final evaluation on TEST split:")
_ = run_eval_and_log("test", test_loader)

# Save raw preds for analysis
_, _, _, _, y_true_test, y_pred_test, y_proba_test = evaluate_model(
    model, test_loader, dataset.label_encoder, device, return_arrays=True
)
np.save(OUT_DIR / "y_true_test.npy", y_true_test)
np.save(OUT_DIR / "y_pred_test.npy", y_pred_test)
np.save(OUT_DIR / "y_proba_test.npy", y_proba_test)
with open(OUT_DIR / "class_names.json", "w") as f:
    json.dump(label_names, f, ensure_ascii=False, indent=2)

# Save checkpoint
model.save_pretrained("./osa_wav2vec2_ckpt")
processor.save_pretrained("./osa_wav2vec2_ckpt")
print("Done. Artifacts saved to:", OUT_DIR.as_posix())

  scaler  = torch.cuda.amp.GradScaler(enabled=use_amp)
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autoc

VAL: {'accuracy': 0.4835164835164835, 'balanced_accuracy': np.float64(0.25), 'f1_macro': 0.16296296296296298, 'f1_weighted': 0.3151811151811152, 'precision_macro': 0.12087912087912088, 'recall_macro': 0.25, 'roc_auc_macro': nan, 'pr_auc_macro': 0.22285319895687272, 'log_loss': 1.1577661110794466, 'support': 182}


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.c

VAL: {'accuracy': 0.4835164835164835, 'balanced_accuracy': np.float64(0.25), 'f1_macro': 0.16296296296296298, 'f1_weighted': 0.3151811151811152, 'precision_macro': 0.12087912087912088, 'recall_macro': 0.25, 'roc_auc_macro': nan, 'pr_auc_macro': 0.24723067317947964, 'log_loss': 1.1552791773388704, 'support': 182}


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.c

VAL: {'accuracy': 0.4835164835164835, 'balanced_accuracy': np.float64(0.25), 'f1_macro': 0.16296296296296298, 'f1_weighted': 0.3151811151811152, 'precision_macro': 0.12087912087912088, 'recall_macro': 0.25, 'roc_auc_macro': nan, 'pr_auc_macro': 0.21438261324330526, 'log_loss': 1.1655670606704727, 'support': 182}


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.c

VAL: {'accuracy': 0.4835164835164835, 'balanced_accuracy': np.float64(0.25), 'f1_macro': 0.16296296296296298, 'f1_weighted': 0.3151811151811152, 'precision_macro': 0.12087912087912088, 'recall_macro': 0.25, 'roc_auc_macro': nan, 'pr_auc_macro': 0.23094844489018182, 'log_loss': 1.161667129714257, 'support': 182}


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.c

VAL: {'accuracy': 0.4835164835164835, 'balanced_accuracy': np.float64(0.25), 'f1_macro': 0.16296296296296298, 'f1_weighted': 0.3151811151811152, 'precision_macro': 0.12087912087912088, 'recall_macro': 0.25, 'roc_auc_macro': nan, 'pr_auc_macro': 0.2271196566395175, 'log_loss': 1.1677235944628346, 'support': 182}


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.c

VAL: {'accuracy': 0.4835164835164835, 'balanced_accuracy': np.float64(0.25), 'f1_macro': 0.16296296296296298, 'f1_weighted': 0.3151811151811152, 'precision_macro': 0.12087912087912088, 'recall_macro': 0.25, 'roc_auc_macro': nan, 'pr_auc_macro': 0.22217489970147222, 'log_loss': 1.1559508851355018, 'support': 182}
Final evaluation on TEST split:
TEST: {'accuracy': 0.5136612021857924, 'balanced_accuracy': np.float64(0.2), 'f1_macro': 0.13574007220216605, 'f1_weighted': 0.34862204336075436, 'precision_macro': 0.10273224043715848, 'recall_macro': 0.2, 'roc_auc_macro': 0.5187176154659218, 'pr_auc_macro': 0.24845034488786594, 'log_loss': 1.1730836461187333, 'support': 183}
Done. Artifacts saved to: eval_out/wav2vec2-base-osa-win4060


## Evaluate

In [46]:
# A) Quick test on one item from test_ds
model.eval()
with torch.no_grad():
    sample = test_ds[0]
    w = sample["audio"]
    if isinstance(w, np.ndarray): w = torch.from_numpy(w)
    w = w.float().view(-1)
    peak = w.abs().max()
    if peak > 0: w = w / peak
    if len(w) > MAX_SECONDS * SAMPLE_RATE:
        w = w[:MAX_SECONDS * SAMPLE_RATE]
    inp = processor([w.numpy()], sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True)
    inp = {k: v.to(device) for k, v in inp.items()}
    logits = model(**inp).logits
    pred_id = int(logits.argmax(dim=-1).item())
pred_label = dataset.label_encoder.inverse_transform([pred_id])[0]
print("Predicted label (sample 0 of test):", pred_label)

Predicted label (sample 0 of test): Normal


In [45]:
# B) Predict from an EDF path + timing (if you want manual inference)
def predict_from_edf(edf_path: str, start_sec: float, duration_sec: float,
                     prefer_audio_channels=("sound","snore","tracheal","microphone","audio","throat")):
    f = pyedflib.EdfReader(edf_path)
    labels = [s.lower() for s in f.getSignalLabels()]
    ch_idx = 0
    for i, name in enumerate(labels):
        if any(key in name for key in prefer_audio_channels):
            ch_idx = i; break
    sig = f.readSignal(ch_idx)
    fs = f.getSampleFrequency(ch_idx)
    f.close()

    s = max(0, int(start_sec * fs))
    e = min(int((start_sec + duration_sec) * fs), len(sig))
    w = torch.tensor(sig[s:e], dtype=torch.float32)
    if fs != SAMPLE_RATE:
        if not HAVE_TA:
            raise RuntimeError("torchaudio is required for resampling")
        w = torchaudio.transforms.Resample(fs, SAMPLE_RATE)(w)

    peak = w.abs().max()
    if peak > 0: w = w / peak
    if len(w) > MAX_SECONDS * SAMPLE_RATE:
        w = w[:MAX_SECONDS * SAMPLE_RATE]

    with torch.no_grad():
        inp = processor([w.numpy()], sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True)
        inp = {k: v.to(device) for k, v in inp.items()}
        logits = model(**inp).logits
        pred_id = int(logits.argmax(dim=-1).item())
    return dataset.label_encoder.inverse_transform([pred_id])[0]

# Example:
# print(predict_from_edf(r"C:\V89\data2\00000999_xyz.edf", start_sec=120, duration_sec=10))

# Use model

In [47]:
from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification
import torch

# โหลดโมเดลและ processor
model = Wav2Vec2ForSequenceClassification.from_pretrained("./osa_wav2vec2_ckpt")
processor = Wav2Vec2Processor.from_pretrained("./osa_wav2vec2_ckpt")

# โหลด label encoder (ถ้าต้องการแปลงผลลัพธ์)
import json
with open("./eval_out/wav2vec2-base-osa-win4060/class_names.json", "r") as f:
    class_names = json.load(f)

# ย้ายโมเดลไป GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

Wav2Vec2ForSequenceClassification(
  (wav2vec2): Wav2Vec2Model(
    (feature_extractor): Wav2Vec2FeatureEncoder(
      (conv_layers): ModuleList(
        (0): Wav2Vec2GroupNormConvLayer(
          (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
          (activation): GELUActivation()
          (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
        )
        (1-4): 4 x Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
        (5-6): 2 x Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
      )
    )
    (feature_projection): Wav2Vec2FeatureProjection(
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (projection): Linear(in_features=512, out_features=768, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
