<a href="https://colab.research.google.com/github/shaddo82/AI-/blob/main/model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Best-effort Google Drive mount: any failure while importing google.colab or
# mounting is swallowed, so the cell finishes without raising.
try:
    from google.colab import drive
    drive.mount('/content/drive')
except Exception:
    pass

Mounted at /content/drive


In [None]:
# (0) Mount Google Drive (best effort: every exception is ignored)
try:
    from google.colab import drive
    drive.mount('/content/drive')
except Exception:
    pass

# (1) 임포트/기본 설정
from pathlib import Path
import os, time, math, gc, random, warnings, json
import numpy as np
import pandas as pd
import torch
import torch.multiprocessing as mp
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import LambdaLR
from contextlib import nullcontext
import torchaudio
import torchaudio.functional as AF
import soundfile as sf
import matplotlib.pyplot as plt
from sklearn.metrics import (
    classification_report, accuracy_score, precision_recall_fscore_support,
    confusion_matrix, roc_auc_score, average_precision_score,
    precision_recall_curve, roc_curve
)
from sklearn.preprocessing import label_binarize

# Filter order matters: warnings.filterwarnings() inserts each new rule at the
# FRONT of the filter list, so the rule registered last is matched first.
# Register the catch-all "default" rule first and the targeted "ignore" rules
# afterwards; in the previous order the catch-all shadowed both ignores.
warnings.filterwarnings("default")
warnings.filterwarnings("ignore", message=".*TorchAudio.*deprecated.*")
warnings.filterwarnings("ignore", message=".*torchcodec.*")

# Force the "spawn" start method for torch.multiprocessing; a RuntimeError
# raised by set_start_method is ignored.
try:
    mp.set_start_method("spawn", force=True)
except RuntimeError:
    pass

# ===== Paths / hyper-parameters =====
SAVE_DIR = "/content/drive/MyDrive/tts_outputs"

# Dataset roots (edit as needed); the trailing comment is the label each feeds.
EDGE_ROOT = Path("/content/drive/MyDrive/tts_outputs/edge_tts")       # tts
GSM_ROOT = Path("/content/drive/MyDrive/tts_outputs/edge_tts_gsm")    # tts_gsm
ORIG_ROOT = Path("/content/drive/MyDrive/origin")                     # orig

# Fast metadata scan options
FAST_SCAN = True
MAX_CSV_PER_DIR = 3

# Audio / loader: every clip is cut or padded to FIXED_SAMPLES samples
TARGET_SR = 16000
FIXED_SECONDS = 2
FIXED_SAMPLES = FIXED_SECONDS * TARGET_SR

SAMPLE_FRAC = 1.0
MAX_PER_CLASS = None

# 3-way split fractions (must sum to 1)
TRAIN_FRAC = 0.80
VAL_FRAC = 0.20
TEST_FRAC = 0.00

# Training
EPOCHS = 5
BATCH_SIZE = 16
VAL_BATCH_SIZE = 16
NUM_WORKERS_TR = 0
NUM_WORKERS_VA = 0
LOG_EVERY_SEC = 2
EMPTY_CACHE_EVERY = 400
SEED = 42

# Model
FREEZE_BACKBONE = False
LR_HEAD = 5e-5
LR_FULL = 1e-6

GRAD_CLIP = 1.0
WARMUP_STEPS = 500

# Fixed 3-label vocabulary and its lookup tables
class_names = ["orig", "tts", "tts_gsm"]
label_map = dict(zip(class_names, range(len(class_names))))
id2label = dict(enumerate(class_names))
label2id = {name: idx for idx, name in id2label.items()}

# Seed every RNG the script uses
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# (2) Auto-generate metadata.csv for the ORIG corpus
def ensure_orig_metadata_csv(orig_root: Path):
    """Write ``<orig_root>/metadata.csv`` listing every ``*.wav`` below the root.

    Nothing is written when the root is missing, when a metadata.csv already
    exists, or when no wav file is found; each case prints a status line and
    returns ``None``. The CSV has a ``path`` column (sorted, recursive search)
    and a ``type`` column that is always ``"orig"``.
    """
    if not (orig_root and orig_root.exists()):
        print(f"[orig] 경로가 없습니다: {orig_root}")
        return

    csv_path = orig_root / "metadata.csv"
    if csv_path.exists():
        print(f"[orig] metadata.csv 존재: {csv_path}")
        return

    # Recursively collect the .wav files under the orig folder
    wav_paths = [str(p) for p in sorted(orig_root.rglob("*.wav"))]
    if not wav_paths:
        print(f"[orig] wav 파일을 찾지 못했습니다: {orig_root}")
        return

    table = pd.DataFrame({"path": wav_paths, "type": ["orig"] * len(wav_paths)})
    table.to_csv(csv_path, index=False)
    print(f"[orig] metadata.csv 생성 완료: {csv_path}  files={len(table)}")

# Make sure the ORIG root has a metadata.csv (created only when it is missing).
ensure_orig_metadata_csv(ORIG_ROOT)

# (3) Metadata loaders
def _pick_some_csvs(root: Path, max_csv=3):
    """Return ``[root / "metadata.csv"]`` when that file exists, otherwise ``[]``.

    ``max_csv`` is accepted for call compatibility but is not used.
    """
    candidate = root / "metadata.csv"
    return [candidate] if candidate.exists() else []


def _read_some_paths(csv_paths, default_type):
    """Load the ``path``/``type`` columns of each metadata CSV and stack them.

    Args:
        csv_paths: iterable of CSV file paths to read.
        default_type: value used for the ``type`` column when a CSV has none.

    Returns:
        A DataFrame with exactly the columns ``path`` and ``type``; empty when
        no CSV could be used.

    The scan is best-effort: a CSV that cannot be read or has no ``path``
    column is skipped, but the skip is reported instead of being silent.
    """
    frames = []
    for csv_path in csv_paths:
        try:
            df = pd.read_csv(csv_path)
            if "path" not in df.columns:
                print(f"[scan] skip csv without 'path' column: {csv_path}")
                continue
            # Literal replacement; regex=False keeps the result independent of
            # the pandas version's default for the `regex` argument.
            df["path"] = df["path"].astype(str).str.replace("//", "/", regex=False)
            if "type" not in df.columns:
                df["type"] = default_type
            frames.append(df[["path", "type"]])
        except Exception as err:
            print(f"[scan] skip unreadable csv: {csv_path} ({err})")
    if not frames:
        return pd.DataFrame(columns=["path", "type"])
    return pd.concat(frames, ignore_index=True)

def read_all_metadata_3class():
    """Merge the metadata of the three dataset roots and normalise the labels.

    Each existing root contributes its ``metadata.csv`` (only while
    ``FAST_SCAN`` is true; there is no slow-scan path). Rows are de-duplicated
    on ``path``, every row's ``type`` is rewritten to one of ``class_names``
    from the original type text and the file path, and only rows whose label
    is in ``class_names`` are returned.

    Raises:
        FileNotFoundError: when no root produced a metadata frame.
    """
    sources = (
        (ORIG_ROOT, "orig"),
        (EDGE_ROOT, "tts"),
        (GSM_ROOT, "tts_gsm"),
    )
    frames = []
    for root, default_type in sources:
        if not (root and root.exists()):
            print(f"[scan] skip (not exists): {root}")
            continue
        if not FAST_SCAN:
            # No alternative scan is implemented: the root contributes nothing.
            continue
        csv_list = _pick_some_csvs(root, max_csv=MAX_CSV_PER_DIR)
        frames.append(_read_some_paths(csv_list, default_type))
    if not frames:
        raise FileNotFoundError("metadata.csv들을 찾지 못했습니다. 루트 경로/파일을 확인하세요.")

    merged = pd.concat(frames, ignore_index=True)
    merged = merged.drop_duplicates(subset=["path"]).reset_index(drop=True)

    # Label normalisation
    original_tags = {"orig", "original", "human"}

    def _normalise(row):
        tag = str(row.get("type", "")).lower()
        loc = str(row.get("path", "")).lower()
        if "gsm" in tag or loc.endswith("_gsm.wav"):
            return "tts_gsm"
        if tag in original_tags or "원본" in loc:
            return "orig"
        if "tts" in tag or loc.endswith("_edge.wav"):
            return "tts"
        # Fallback: guess from the file path
        if "tts" in loc:
            return "tts_gsm" if "gsm" in loc else "tts"
        return "orig"

    merged["type"] = merged.apply(_normalise, axis=1)

    # Keep only the three known labels (no existence/size check happens here)
    return merged[merged["type"].isin(class_names)].copy()

def balance_classes(df, strategy="min"):
    """Down-sample every ``type`` group to the size of the smallest group.

    Prints the per-class counts before and after. Sampling is seeded with the
    module-level ``SEED``. Only ``strategy="min"`` is supported; any other
    value raises ``ValueError``.
    """
    per_class = df["type"].value_counts().to_dict()
    print("[before balance]", per_class)

    if strategy != "min":
        raise ValueError("Unsupported strategy")
    quota = min(per_class.values())

    sampled = [
        group.sample(n=quota, random_state=SEED)
        for _, group in df.groupby("type", group_keys=False)
    ]
    balanced = pd.concat(sampled, ignore_index=True)
    print("[after balance]", balanced["type"].value_counts().to_dict())
    return balanced

# Load the merged metadata, then down-sample every class to the smallest class size.
m = read_all_metadata_3class()
m = balance_classes(m, strategy="min")
# Report the per-label row counts and the total after balancing.
print(f"[labels] kept types:", m["type"].value_counts().to_dict())
print(f"[total] {len(m):,} rows  FAST_SCAN={FAST_SCAN}")

# (4) Per-label sampling -> 3-way split
def stratified_sample(df, frac=0.6, max_per_class=None, seed=42):
    """Sample a fraction of each ``type`` group, optionally capped per class.

    For every group, ``int(len(group) * frac)`` rows are drawn (the whole group
    is kept when that count is not smaller than the group). If ``max_per_class``
    is given and the result is still larger, it is sampled down again to that
    size. Returns the groups concatenated with a fresh index, or an empty frame
    with ``df``'s columns when there are no groups.
    """
    picked = []
    for _, group in df.groupby("type", group_keys=False):
        want = int(len(group) * frac)
        if want < len(group):
            subset = group.sample(n=want, random_state=seed)
        else:
            subset = group
        over_cap = max_per_class is not None and len(subset) > max_per_class
        if over_cap:
            subset = subset.sample(n=max_per_class, random_state=seed)
        picked.append(subset)

    if not picked:
        return pd.DataFrame(columns=df.columns)
    return pd.concat(picked, ignore_index=True)

def stratified_split_3(df, train_frac, val_frac, test_frac, seed=42):
    """Split ``df`` into train/val/test frames, stratified by the ``type`` column.

    Args:
        df: DataFrame with a ``type`` column.
        train_frac, val_frac, test_frac: split fractions; must sum to 1.
        seed: random state used to shuffle each group before slicing.

    Returns:
        ``(train_df, val_df, test_df)``, each with a fresh 0..n-1 index and the
        same columns as ``df``.

    Every row lands in exactly one split. The per-group sizes are
    ``int(n * frac)``, and the rows left over by that truncation go to the
    last split with a non-zero fraction (test, else val, else train) instead
    of being dropped. An empty ``df`` yields three empty frames.
    """
    assert abs((train_frac+val_frac+test_frac)-1.0) < 1e-6

    train_parts, val_parts, test_parts = [], [], []
    for _, group in df.groupby("type", group_keys=False):
        group = group.sample(frac=1.0, random_state=seed)
        n = len(group)
        n_tr = int(n * train_frac)
        n_val = int(n * val_frac)

        if test_frac > 0:
            # test takes the remainder through the open-ended slice below
            pass
        elif val_frac > 0:
            n_val = n - n_tr
        else:
            n_tr = n

        train_parts.append(group.iloc[:n_tr])
        val_parts.append(group.iloc[n_tr:n_tr + n_val])
        test_parts.append(group.iloc[n_tr + n_val:])

    def _stack(parts):
        # pd.concat rejects an empty list, so fall back to an empty slice of df.
        if not parts:
            return df.iloc[0:0].reset_index(drop=True)
        return pd.concat(parts, ignore_index=True)

    return _stack(train_parts), _stack(val_parts), _stack(test_parts)


# Example audio augmentation (applied during training)
def augment_audio(wav, sr):
    """Return ``wav`` with additive Gaussian noise (std 0.005), clamped to [-1, 1].

    Args:
        wav: audio tensor of any shape; it is converted to float32.
        sr: sample rate; accepted for interface symmetry but not used.

    Returns:
        A float32 tensor with the same shape and device as ``wav``.

    The noise is drawn with ``randn_like`` so it always matches the input's
    shape and device; sizing it from ``len(wav)`` only worked for 1-D CPU
    tensors.
    """
    clean = wav.float()
    noise = torch.randn_like(clean) * 0.005
    return torch.clamp(clean + noise, -1.0, 1.0)



# Sub-sample each class, fail fast when nothing is left, then build the
# train/val/test frames and report their sizes overall and per label.
m = stratified_sample(m, frac=SAMPLE_FRAC, max_per_class=MAX_PER_CLASS, seed=SEED)
assert not m.empty, "샘플링 결과가 비었습니다. SAMPLE_FRAC을 키우거나 경로/메타를 확인하세요."
df_tr, df_val, df_te = stratified_split_3(m, TRAIN_FRAC, VAL_FRAC, TEST_FRAC, seed=SEED)
print(f"[split] train={len(df_tr):,}  val={len(df_val):,}  test={len(df_te):,}")
print("[by_type] train:", df_tr["type"].value_counts().to_dict())
print("[by_type] val  :", df_val["type"].value_counts().to_dict())
print("[by_type] test :", df_te["type"].value_counts().to_dict())

# (5) WAV loading utilities (16 kHz mono, fixed-length random/centre crop)
def load_wav_16k_mono(path, retry=1):
    """Load ``path`` as a 1-D float32 tensor resampled to ``TARGET_SR``.

    Each attempt tries soundfile first and falls back to torchaudio. When both
    fail the function sleeps 30 ms and tries again, for ``max(1, retry + 1)``
    attempts in total. Multi-channel audio is averaged to mono.

    Returns:
        The waveform tensor, or ``None`` when every attempt failed.
    """
    attempts = max(1, retry + 1)
    last_err = None
    for _ in range(attempts):
        # Primary decoder: soundfile
        try:
            samples, rate = sf.read(path, dtype="float32", always_2d=False)
            wav = torch.from_numpy(samples)
            if wav.ndim > 1:
                wav = wav.mean(dim=-1)
            if rate != TARGET_SR:
                wav = AF.resample(wav.unsqueeze(0), rate, TARGET_SR).squeeze(0)
            return wav
        except Exception as sf_err:
            last_err = sf_err

        # Backup decoder: torchaudio
        try:
            loaded, rate = torchaudio.load(path)
            if loaded.dtype != torch.float32:
                loaded = loaded.to(torch.float32)
            if loaded.size(0) > 1:
                loaded = loaded.mean(dim=0, keepdim=True)
            if rate != TARGET_SR:
                loaded = AF.resample(loaded, rate, TARGET_SR)
            return loaded.squeeze(0)
        except Exception as ta_err:
            last_err = ta_err
            time.sleep(0.03)
    return None

def fixed_length_random(wav: torch.Tensor, fixed_len: int = FIXED_SAMPLES):
    """Flatten ``wav`` and return exactly ``fixed_len`` samples.

    Longer inputs are cropped at a random offset (Python ``random``); shorter
    inputs are right-padded with zeros; inputs of the exact length are returned
    flattened and otherwise unchanged.
    """
    flat = wav.view(-1)
    length = flat.numel()
    if length > fixed_len:
        offset = random.randint(0, length - fixed_len)
        return flat[offset:offset + fixed_len]
    if length < fixed_len:
        padded = torch.zeros(fixed_len, dtype=flat.dtype)
        padded[:length] = flat
        return padded
    return flat

def make_loader_from_wavs(df, batch_size, shuffle, num_workers, pin_memory=True, augment=False):
    """Build a ``DataLoader`` that reads the wav files listed in ``df``.

    Args:
        df: frame with ``path`` and ``type`` columns; ``None`` or empty
            returns ``None``.
        batch_size, shuffle, num_workers, pin_memory: forwarded to DataLoader.
        augment: when true, each clip is noise-augmented with probability 0.3.

    Returns:
        A DataLoader whose batches are dicts with ``input_values``,
        ``attention_mask`` (all ones), ``labels`` and an ``empty`` flag. Files
        that are missing, smaller than 100 bytes or undecodable are dropped
        from the batch; a batch with no usable file becomes a single zero
        sample flagged ``empty=True``.
    """
    if df is None or df.empty:
        return None
    wav_paths = df["path"].tolist()
    wav_types = df["type"].tolist()

    class _DS(Dataset):
        def __len__(self):
            return len(wav_paths)

        def __getitem__(self, idx):
            if idx % 1000 == 0:
                print(f"[WAV-LOAD] idx={idx}/{len(wav_paths)}")
            path = wav_paths[idx]
            kind = wav_types[idx]
            try:
                usable = os.path.exists(path) and os.path.getsize(path) >= 100
            except Exception:
                return None, None
            if not usable:
                return None, None
            wav = load_wav_16k_mono(path, retry=1)
            if wav is None:
                return None, None
            wav = fixed_length_random(wav, FIXED_SAMPLES)
            if augment and random.random() < 0.3:
                wav = augment_audio(wav, TARGET_SR)
            return wav, label_map[kind]

    def _collate(batch):
        valid = [(w, l) for (w, l) in batch if not (w is None or l is None)]
        if valid:
            waves, lbls = zip(*valid)
            xb = torch.stack(waves, dim=0)
            yb = torch.tensor(lbls, dtype=torch.long)
            am = torch.ones((xb.size(0), xb.size(1)), dtype=torch.long)
            return {"input_values": xb, "attention_mask": am, "labels": yb, "empty": False}
        # Nothing usable in this batch: emit a one-sample placeholder.
        xb = torch.zeros((1, FIXED_SAMPLES), dtype=torch.float32)
        yb = torch.zeros((1,), dtype=torch.long)
        am = torch.ones((1, FIXED_SAMPLES), dtype=torch.long)
        return {"input_values": xb, "attention_mask": am, "labels": yb, "empty": True}

    # num_workers=0 (prefetch_factor left unset) -> avoids a stall on the first batch
    loader = DataLoader(
        _DS(),
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=False,
        collate_fn=_collate,
    )
    return loader

# Build the three loaders; only the training loader shuffles and augments.
# Each call returns None when its frame is empty.
train_dl = make_loader_from_wavs(df_tr,  BATCH_SIZE,   True,  NUM_WORKERS_TR, augment=True)
val_dl   = make_loader_from_wavs(df_val, VAL_BATCH_SIZE, False, NUM_WORKERS_VA, augment=False)
test_dl  = make_loader_from_wavs(df_te,  VAL_BATCH_SIZE, False, NUM_WORKERS_VA, augment=False)

# Steps per phase (0 when the loader is missing/empty); used for progress logs.
steps_tr = math.ceil(len(df_tr)/BATCH_SIZE) if train_dl else 0
steps_va = math.ceil(len(df_val)/VAL_BATCH_SIZE) if val_dl else 0
steps_te = math.ceil(len(df_te)/VAL_BATCH_SIZE) if test_dl else 0
print(f"[READY] fixed_seconds={FIXED_SECONDS}s  batch={BATCH_SIZE}/{VAL_BATCH_SIZE}  workers={NUM_WORKERS_TR}/{NUM_WORKERS_VA}")

# (6) 모델/학습 루프
from transformers import Wav2Vec2ForSequenceClassification
# Pick the device and the autocast dtype: bf16 when the GPU supports it,
# fp16 otherwise, and no autocast dtype at all on CPU.
device = "cpu"
amp_dtype = None
if torch.cuda.is_available():
    device = "cuda"
    amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    torch.backends.cudnn.benchmark = True
    # Best effort: ignore any failure while setting the matmul precision.
    try:
        torch.set_float32_matmul_precision("high")
    except Exception:
        pass

# Everything below (model, training, evaluation, saving) runs only when at
# least one of the train/val loaders exists.
if train_dl or val_dl:
    # 3-label sequence classifier on top of the pretrained wav2vec2-base.
    model = Wav2Vec2ForSequenceClassification.from_pretrained(
        "facebook/wav2vec2-base",
        num_labels=3, id2label=id2label, label2id=label2id
    ).to(device)

    # Optionally freeze everything except the classifier/projector head.
    # NOTE(review): `lr` is assigned here but not read anywhere in the visible
    # code; the optimizer below uses LR_HEAD / LR_FULL directly.
    if FREEZE_BACKBONE:
        for name, p in model.named_parameters():
            if not (name.startswith("classifier.") or name.startswith("projector.")):
                p.requires_grad = False
        lr = LR_HEAD
    else:
        lr = LR_FULL

    # Build the optimizer (required): head and body get separate learning rates
    head_keys = ["classifier", "projector"]
    params_head, params_body = [], []

    for name, p in model.named_parameters():
        if any(k in name for k in head_keys):
            params_head.append(p)
        else:
            params_body.append(p)

    optimizer = torch.optim.AdamW([
        {"params": params_head, "lr": LR_HEAD},
        {"params": params_body, "lr": LR_FULL},
    ], weight_decay=0.01)

    # Warm-up schedule: LR multiplier grows linearly from 0 to 1 over
    # WARMUP_STEPS scheduler steps, then stays at 1.
    def warmup_lambda(step):
        if step < WARMUP_STEPS:
            return float(step) / float(max(1, WARMUP_STEPS))
        return 1.0

    scheduler = LambdaLR(optimizer, warmup_lambda)

    # Gradient scaler; run_epoch only uses it when device == "cuda".
    scaler = torch.amp.GradScaler("cuda")

    class EMA:
        """Exponential moving average of throughput in samples per second."""

        def __init__(self, alpha=0.2):
            self.alpha = alpha
            self.sps = None
            self.t = None

        def tick(self, n):
            """Record that ``n`` samples were processed since the previous tick.

            The first call only stores the timestamp.
            """
            now = time.time()
            if self.t is None:
                self.t = now
                return
            elapsed = max(now - self.t, 1e-6)
            rate = n / elapsed
            if self.sps is None:
                self.sps = rate
            else:
                self.sps = self.alpha * rate + (1 - self.alpha) * self.sps
            self.t = now

        def sec_per_step(self, bs):
            """Return the estimated seconds for one batch of ``bs`` samples (inf if unknown)."""
            if not self.sps or self.sps <= 0:
                return float("inf")
            return bs / self.sps
    # One throughput tracker per phase, plus the total step count over all
    # epochs (train + valid) used for the overall progress percentage.
    ema_tr, ema_va = EMA(0.2), EMA(0.2)
    total_steps_all = EPOCHS*(steps_tr+steps_va)

    def pretty_eta(sec: float) -> str:
        """Format a duration in seconds as ``"1h 02m 03s"``, ``"2m 03s"`` or ``"3s"``.

        ``None`` and non-finite values give ``"estimating..."``. The value is
        rounded to whole seconds with ``int(sec + 0.5)``.
        """
        if sec is None or not math.isfinite(sec):
            return "estimating..."
        total = int(sec + 0.5)
        hours, rest = divmod(total, 3600)
        minutes, seconds = divmod(rest, 60)
        if hours:
            return f"{hours}h {minutes:02d}m {seconds:02d}s"
        if minutes:
            return f"{minutes}m {seconds:02d}s"
        return f"{seconds}s"

    # Train and valid modes are separated; gradients are disabled only for validation.

    def run_epoch(dl, train=True, epoch_idx=0):
        """Run one pass over ``dl`` and return the mean per-sample loss.

        Args:
            dl: DataLoader yielding the batch dicts built by the loader
                factory; ``None`` skips the phase.
            train: True to optimise the model, False to evaluate only.
            epoch_idx: zero-based epoch number, used only for log lines.

        Returns:
            Sum of ``loss * batch_size`` divided by the number of samples seen,
            or ``inf`` when the phase is skipped.

        Batches flagged ``empty`` are skipped. On CUDA the backward pass goes
        through the gradient scaler; gradients are clipped to ``GRAD_CLIP`` and
        the scheduler is stepped once per training batch. A progress line is
        printed at most every ``LOG_EVERY_SEC`` seconds.
        """
        if dl is None or dl.dataset is None:
            print(f"Skip {'train' if train else 'valid'} (empty)")
            return float('inf')

        model.train(train)
        phase = "train" if train else "valid"
        steps_phase = steps_tr if train else steps_va
        bs_phase    = BATCH_SIZE if train else VAL_BATCH_SIZE
        # `seen` counts samples since the last log line (for the speed figure);
        # `tot_loss` and `seen_total` cover the whole epoch.
        tot_loss, seen, t0 = 0.0, 0, time.time()
        seen_total = 0

        ctx = nullcontext() if not amp_dtype else torch.autocast(
            device_type=device, dtype=amp_dtype, enabled=(device=="cuda")
        )

        # Disable gradient tracking (inference_mode) only when not training
        grad_ctx = nullcontext() if train else torch.inference_mode()

        with grad_ctx:
            for step, batch in enumerate(dl, 1):
                if batch.get("empty", False):
                    continue

                xb = batch["input_values"].to(device, non_blocking=True)
                am = batch["attention_mask"].to(device, non_blocking=True)
                yb = batch["labels"].to(device, non_blocking=True)

                with ctx:
                    out  = model(input_values=xb, attention_mask=am, labels=yb)
                    loss = out.loss

                if train:
                    optimizer.zero_grad(set_to_none=True)

                    if device == "cuda":
                        scaler.scale(loss).backward()
                        # Unscale before clipping so the threshold applies to true gradients
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                        optimizer.step()

                    scheduler.step()

                bs = xb.size(0)
                tot_loss += loss.item() * bs
                seen += bs
                seen_total += bs
                (ema_tr if train else ema_va).tick(bs)

                if device == "cuda" and (step % EMPTY_CACHE_EVERY == 0):
                    torch.cuda.empty_cache()

                now = time.time()
                if now - t0 >= LOG_EVERY_SEC:
                    sps = seen / max(now - t0, 1e-6)
                    sec_per = (ema_tr if train else ema_va).sec_per_step(bs_phase)
                    done_all = (epoch_idx*(steps_tr+steps_va)) + (step if train else steps_tr + step)
                    pct = 100.0 * done_all / max(1, total_steps_all)
                    # Running epoch mean: tot_loss is cumulative, so it must be
                    # divided by seen_total, not by the per-window counter
                    # `seen` that is reset after every log line.
                    print(f"[{phase} E{epoch_idx+1}/{EPOCHS} step {step}/{steps_phase}] "
                          f"loss={tot_loss/max(1,seen_total):.4f}  speed={sps:.1f} samp/s  "
                          f"prog={pct:.1f}%  ETA(phase)={pretty_eta((steps_phase-step)*sec_per)}")
                    t0 = now; seen = 0

                del xb, am, yb, out, loss

        return tot_loss / max(1, seen_total)


    # Alternate one training pass and one validation pass per epoch.
    for ep in range(EPOCHS):
        tr_loss = run_epoch(train_dl, True, ep)
        if device == "cuda":
            # Collect Python garbage first so the memory it releases can be
            # handed back by empty_cache().
            gc.collect()
            torch.cuda.empty_cache()
        va_loss = run_epoch(val_dl, False, ep)
        if device == "cuda":
            gc.collect()
            torch.cuda.empty_cache()
        # 1-based epoch number, matching the per-step log lines of run_epoch.
        print(f"[E{ep+1}/{EPOCHS}] train={tr_loss:.4f}  valid={va_loss:.4f}")

    # (7) Test evaluation + visualisation (whole test_dl)
    def eval_collect(dl):
        """Run the model over ``dl`` and collect the loss, labels and logits.

        Returns:
            ``(mean_batch_loss, y_true, logits)``: the unweighted mean of the
            per-batch losses (``inf`` when no batch was evaluated), the label
            array and the float32 logits. ``(None, None, None)`` when ``dl`` is
            missing. Batches flagged ``empty`` are skipped.
        """
        if dl is None or dl.dataset is None:
            return None, None, None
        model.eval()

        batch_losses = []
        y_true = []
        logit_chunks = []
        if amp_dtype:
            amp_ctx = torch.autocast(device_type=device, dtype=amp_dtype, enabled=(device=="cuda"))
        else:
            amp_ctx = nullcontext()

        with torch.inference_mode(), amp_ctx:
            for batch in dl:
                if batch.get("empty", False):
                    continue
                xb = batch["input_values"].to(device, non_blocking=True)
                am = batch["attention_mask"].to(device, non_blocking=True)
                yb = batch["labels"].to(device, non_blocking=True)
                out = model(input_values=xb, attention_mask=am, labels=yb)
                batch_losses.append(out.loss.item())
                # bf16 -> fp32 before converting to numpy
                logit_chunks.append(out.logits.float().detach().cpu().numpy())
                y_true.extend(yb.cpu().tolist())

        if logit_chunks:
            logits = np.concatenate(logit_chunks, axis=0)
        else:
            logits = np.zeros((0, len(class_names)))
        mean_loss = float(np.mean(batch_losses)) if batch_losses else float('inf')
        return mean_loss, np.array(y_true), logits

    # Evaluate on the test loader and make sure the report directory exists.
    te_loss, y_true, logits = eval_collect(test_dl)
    out_eval_dir = Path(f"{SAVE_DIR}/eval_reports_train_3class_nocache")
    out_eval_dir.mkdir(parents=True, exist_ok=True)

    if y_true is None or y_true.size == 0:
        print("[TEST] no test samples")
    else:
        # Class probabilities and hard predictions from the raw logits.
        probs = torch.softmax(torch.tensor(logits), dim=1).numpy()
        y_pred = probs.argmax(axis=1)

        # Accuracy plus macro / micro / weighted precision, recall and F1.
        acc = accuracy_score(y_true, y_pred)
        prec_ma, rec_ma, f1_ma, _ = precision_recall_fscore_support(y_true, y_pred, average="macro",  zero_division=0)
        prec_mi, rec_mi, f1_mi, _ = precision_recall_fscore_support(y_true, y_pred, average="micro",  zero_division=0)
        prec_w,  rec_w,  f1_w,  _ = precision_recall_fscore_support(y_true, y_pred, average="weighted", zero_division=0)

        # Per-class report, written to CSV.
        labels_order = list(range(len(class_names)))
        report = classification_report(
            y_true, y_pred, labels=labels_order, target_names=class_names,
            zero_division=0, output_dict=True
        )
        rep_df = pd.DataFrame(report).T
        rep_csv = out_eval_dir / "test_classification_report.csv"
        rep_df.to_csv(rep_csv)

        # Robust binarize (always shaped (N, 3))
        Y_bin = label_binarize(y_true, classes=labels_order)
        if Y_bin.ndim == 1 or Y_bin.shape[1] != len(class_names):
            # Pad with dummy columns when some are missing
            pad = len(class_names) - (Y_bin.shape[1] if Y_bin.ndim > 1 else 1)
            if Y_bin.ndim == 1:
                Y_bin = np.column_stack([1 - Y_bin, Y_bin])
                pad = len(class_names) - Y_bin.shape[1]
            if pad > 0:
                Y_bin = np.hstack([Y_bin, np.zeros((Y_bin.shape[0], pad), dtype=Y_bin.dtype)])

        # AP per class / mAP (classes without positives count as NaN and are
        # ignored by nanmean)
        ap_per_class = []
        for c in range(len(class_names)):
            ap_c = average_precision_score(Y_bin[:, c], probs[:, c]) if Y_bin[:, c].sum() > 0 else float("nan")
            ap_per_class.append(ap_c)
        mAP = np.nanmean(ap_per_class)

        # ROC-AUC (one-vs-rest); any failure is recorded as NaN
        try:
            roc_macro = roc_auc_score(Y_bin, probs, average="macro", multi_class="ovr")
        except Exception:
            roc_macro = float("nan")

        # Confusion matrix, row-normalised by the number of true samples per class
        # NOTE(review): a class with no true samples gives a zero row sum here.
        cm = confusion_matrix(y_true, y_pred, labels=labels_order)
        cm_norm = cm.astype("float") / cm.sum(axis=1, keepdims=True)

        def plot_cm(cm, labels, fname):
            """Draw ``cm`` as an annotated heat-map and save the figure to ``fname``.

            Each cell shows its value with two decimals; cells above 60% of the
            (NaN-ignoring) maximum are written in white, the rest in black.
            """
            n_rows, n_cols = cm.shape[0], cm.shape[1]
            fig, ax = plt.subplots(figsize=(5,4), dpi=140)
            heat = ax.imshow(cm, interpolation='nearest')
            ax.figure.colorbar(heat, ax=ax)
            ax.set(
                xticks=np.arange(n_cols), yticks=np.arange(n_rows),
                xticklabels=labels, yticklabels=labels,
                ylabel='True', xlabel='Predicted',
                title='Confusion Matrix (normalized)',
            )
            plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
            cutoff = np.nanmax(cm) * 0.6
            for r in range(n_rows):
                for c in range(n_cols):
                    cell = cm[r, c]
                    colour = "white" if cell > cutoff else "black"
                    ax.text(c, r, f"{cell:.2f}", ha="center", va="center", color=colour)
            fig.tight_layout()
            fig.savefig(fname, bbox_inches="tight")
            plt.close(fig)

        # Save the normalised confusion matrix next to the other reports.
        cm_png = out_eval_dir / "test_confusion_matrix_norm.png"
        plot_cm(cm_norm, class_names, cm_png)

        def plot_pr(Y, P, labels, out_png):
            """Plot one precision-recall curve per class and save it to ``out_png``.

            ``Y`` holds the binarised labels and ``P`` the scores, one column
            per class. Classes without any positive sample are skipped.
            """
            fig, ax = plt.subplots(figsize=(6,5), dpi=140)
            for cls_idx in range(P.shape[1]):
                truth = Y[:, cls_idx]
                if truth.sum() == 0:
                    continue
                scores = P[:, cls_idx]
                prec, rec, _ = precision_recall_curve(truth, scores)
                ap_c = average_precision_score(truth, scores)
                ax.plot(rec, prec, label=f"{labels[cls_idx]} (AP={ap_c:.3f})")
            ax.set_xlabel("Recall")
            ax.set_ylabel("Precision")
            ax.set_title("Precision–Recall Curves (Test)")
            ax.legend()
            fig.tight_layout()
            fig.savefig(out_png, bbox_inches="tight")
            plt.close(fig)

        def plot_roc(Y, P, labels, out_png):
            """Plot one ROC curve per class (one-vs-rest) and save it to ``out_png``.

            ``Y`` holds the binarised labels and ``P`` the scores, one column
            per class. Classes without any positive sample are skipped. The
            area in the legend is the trapezoidal integral of the curve.
            """
            # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid;
            # use the new name when it exists and fall back on older NumPy.
            trapezoid = getattr(np, "trapezoid", None) or np.trapz
            fig, ax = plt.subplots(figsize=(6,5), dpi=140)
            for c in range(P.shape[1]):
                if Y[:, c].sum() == 0: continue
                fpr, tpr, _ = roc_curve(Y[:, c], P[:, c])
                auc_c = trapezoid(tpr, fpr)
                ax.plot(fpr, tpr, label=f"{labels[c]} (AUC~={auc_c:.3f})")
            # Chance-level diagonal for reference
            ax.plot([0,1],[0,1],'--')
            ax.set_xlabel("FPR"); ax.set_ylabel("TPR"); ax.set_title("ROC Curves (OVR, Test)")
            ax.legend(); fig.tight_layout(); fig.savefig(out_png, bbox_inches="tight"); plt.close(fig)

        # Render the PR and ROC figures into the report directory.
        pr_png  = out_eval_dir / "test_pr_curves.png"
        roc_png = out_eval_dir / "test_roc_curves.png"
        plot_pr(Y_bin, probs, class_names, pr_png)
        plot_roc(Y_bin, probs, class_names, roc_png)

        # Keep the raw logits and labels so metrics can be recomputed later.
        np.savez_compressed(out_eval_dir / "test_raw_outputs.npz", logits=logits, y_true=y_true)

        # Collect every metric and artefact path in one JSON-serialisable dict;
        # NaN values are stored as None.
        summary = {
            "samples": int(len(y_true)),
            "classes": class_names,
            "test_loss": float(te_loss),
            "accuracy": float(acc),
            "precision_macro": float(prec_ma),
            "recall_macro": float(rec_ma),
            "f1_macro": float(f1_ma),
            "precision_micro": float(prec_mi),
            "recall_micro": float(rec_mi),
            "f1_micro": float(f1_mi),
            "precision_weighted": float(prec_w),
            "recall_weighted": float(rec_w),
            "f1_weighted": float(f1_w),
            "AP_per_class": {
                class_names[i]: (None if np.isnan(ap_per_class[i]) else float(ap_per_class[i]))
                for i in range(len(class_names))
            },
            "mAP": (None if np.isnan(mAP) else float(mAP)),
            # `x != x` is true only for NaN
            "roc_auc_macro_ovr": (None if (roc_macro!=roc_macro) else float(roc_macro)),
            "report_csv": str(rep_csv),
            "cm_png": str(cm_png),
            "pr_png": str(pr_png),
            "roc_png": str(roc_png),
            "raw_npz": str(out_eval_dir / "test_raw_outputs.npz"),
        }
        with open(out_eval_dir / "test_metrics_summary.json", "w") as f:
            json.dump(summary, f, indent=2)

        # Echo the summary; floats are printed with four decimals.
        print("\n=== TEST SUMMARY (3-class) ===")
        for k, v in summary.items():
            if isinstance(v, float): print(f"{k:>24}: {v:.4f}")
            else:                    print(f"{k:>24}: {v}")
        print(f"\n✔ 결과 저장: {out_eval_dir}")

    # (8) Save the model: both as a Hugging Face directory (save_pretrained)
    # and as a plain state_dict file.
    model_dir  = f"{SAVE_DIR}/wav2vec2_clf_model3"
    state_path = f"{SAVE_DIR}/wav2vec2_clf3.pt"
    Path(model_dir).mkdir(parents=True, exist_ok=True)
    model.save_pretrained(model_dir)
    torch.save(model.state_dict(), state_path)
    print(f"\n✔ 모델 저장 완료: {model_dir} & {state_path}")

# Neither a train nor a val loader exists: nothing to train or evaluate.
else:
    print("Skipping training: no train/val data.")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
[orig] metadata.csv 존재: /content/drive/MyDrive/origin/metadata.csv


  return datetime.utcnow().replace(tzinfo=utc)


[before balance] {'tts': 254300, 'tts_gsm': 254292, 'orig': 54066}
[after balance] {'orig': 54066, 'tts': 54066, 'tts_gsm': 54066}
[labels] kept types: {'orig': 54066, 'tts': 54066, 'tts_gsm': 54066}
[total] 162,198 rows  FAST_SCAN=True


  return datetime.utcnow().replace(tzinfo=utc)


[split] train=77,853  val=19,461  test=0
[by_type] train: {'orig': 25951, 'tts': 25951, 'tts_gsm': 25951}
[by_type] val  : {'orig': 6487, 'tts': 6487, 'tts_gsm': 6487}
[by_type] test : {}
[READY] fixed_seconds=2s  batch=16/16  workers=0/0


  return datetime.utcnow().replace(tzinfo=utc)
Some weights of Wav2Vec2ForSequenceClassification were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'projector.bias', 'projector.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  return data.pin_memory(device)
  return data.pin_memory(device)


[train E1/5 step 1/4866] loss=1.0938  speed=0.9 samp/s  prog=0.0%  ETA(phase)=estimating...
[train E1/5 step 2/4866] loss=2.2824  speed=0.8 samp/s  prog=0.0%  ETA(phase)=27h 54m 43s
[train E1/5 step 3/4866] loss=3.5529  speed=0.5 samp/s  prog=0.0%  ETA(phase)=30h 26m 15s


KeyboardInterrupt: 

In [None]:
import os
import torch
import numpy as np
import soundfile as sf
import torchaudio
import random
from pathlib import Path
from sklearn.preprocessing import label_binarize
from sklearn.metrics import (
    classification_report, confusion_matrix, accuracy_score,
    roc_auc_score, roc_curve
)
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2Processor
import matplotlib.pyplot as plt


# ======================
# Path settings
# ======================
BASE_TEST_DIR = "/content/drive/MyDrive/test"
# NOTE(review): this loads ".../wav2vec2_clf_model2", while the training cell
# in this notebook saves to ".../wav2vec2_clf_model3" -- confirm which
# checkpoint is meant to be evaluated.
MODEL_DIR     = "/content/drive/MyDrive/tts_outputs/wav2vec2_clf_model2"

device = "cuda" if torch.cuda.is_available() else "cpu"

print("[INFO] Loading model from", MODEL_DIR)

# Load the fine-tuned classifier and switch it to eval mode.
model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_DIR).to(device)
model.eval()

# The processor comes from the base checkpoint, not from MODEL_DIR.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")

# Label name <-> class index lookup tables.
label_map = {"orig": 0, "tts": 1, "tts_gsm": 2}
id2label  = {v: k for k, v in label_map.items()}


# ======================
# WAV loader
# ======================
def load_wav_safe(path, target_sr=16000):
    """Read ``path`` with soundfile and return a mono float32 tensor at ``target_sr``.

    Multi-channel audio is averaged over its last axis, and the signal is
    resampled only when the file's sample rate differs from ``target_sr``.
    Read errors propagate to the caller.
    """
    samples, file_sr = sf.read(path, dtype="float32", always_2d=False)
    signal = torch.tensor(samples, dtype=torch.float32)

    is_multichannel = signal.ndim > 1
    if is_multichannel:
        signal = signal.mean(dim=-1)

    if file_sr == target_sr:
        return signal

    batched = signal.unsqueeze(0)
    return torchaudio.functional.resample(batched, file_sr, target_sr).squeeze(0)


# ======================
# Prediction function
# ======================
def predict_logits(wav):
    """Run the classifier on one waveform and return its logits as a 1-D numpy array.

    The waveform is passed through the module-level ``processor`` at 16 kHz,
    and an all-ones attention mask is built by hand for the model call.
    """
    encoded = processor(wav, sampling_rate=16000, return_tensors="pt", padding="longest")
    features = encoded["input_values"]

    # Build the attention mask directly (every position attended)
    mask = torch.ones_like(features, dtype=torch.long)

    with torch.no_grad():
        output = model(
            input_values=features.to(device),
            attention_mask=mask.to(device),
        )

    # First (only) row of the batch -> 1-D logits
    return output.logits.cpu().numpy()[0]


# ======================
# Folder test (random SAMPLE_FRAC of each class folder)
# ======================
y_true = []
y_pred = []
y_prob = []   # softmax probabilities, one row per evaluated file

# Fraction of each class folder that is evaluated (0.5 = half of the files).
SAMPLE_FRAC = 0.5

for class_name in ["orig", "tts", "tts_gsm"]:
    folder = Path(BASE_TEST_DIR) / class_name
    if not folder.exists():
        print(f"[WARN] 폴더 없음: {folder}")
        continue

    file_list = list(folder.glob("*.wav")) + list(folder.glob("*.mp3"))
    total_files = len(file_list)
    if total_files == 0:
        # random.sample() raises on an empty population, so skip the folder.
        print(f"[WARN] 오디오 파일 없음: {folder}")
        continue

    # At least one file, never more than the folder holds.
    k = min(total_files, max(1, int(total_files * SAMPLE_FRAC)))
    file_list = random.sample(file_list, k)
    print(f"[TEST] {class_name}: total={total_files}, sampled={k}")

    for f in file_list:
        # A file that fails to load or predict is reported and left out of
        # the metrics instead of aborting the run.
        try:
            wav = load_wav_safe(str(f))
            logits = predict_logits(wav)
            probs = torch.softmax(torch.tensor(logits), dim=-1).numpy()

            pred = np.argmax(probs)

            y_true.append(label_map[class_name])
            y_pred.append(pred)
            y_prob.append(probs)

        except Exception as e:
            print(f"[ERR] {f}: {e}")


# Convert the collected lists to numpy arrays
y_true = np.array(y_true)
y_pred = np.array(y_pred)
y_prob = np.array(y_prob)


# ======================
# Basic metrics
# ======================
accuracy = accuracy_score(y_true, y_pred)
print("\n전체 Accuracy:", accuracy)

# Pin the label set so the matrix is always 3x3, matching the fixed tick
# labels used when it is plotted, even if a class never occurs in the sample.
cm = confusion_matrix(y_true, y_pred, labels=[0, 1, 2])

# Row-normalised confusion matrix; a class with no true samples keeps an
# all-zero row instead of dividing by zero.
row_totals = cm.sum(axis=1, keepdims=True)
row_totals[row_totals == 0] = 1
cm_norm = cm.astype(float) / row_totals


# ======================
# Confusion matrix plot (row-normalised, predictions from softmax argmax)
# ======================
plt.figure(figsize=(6, 5))
plt.imshow(cm_norm, cmap="Blues")
plt.title("Confusion Matrix (Normalized)")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.xticks([0,1,2], ["orig","tts","tts_gsm"])
plt.yticks([0,1,2], ["orig","tts","tts_gsm"])

# Annotate every cell; white text on cells above 0.5, black otherwise.
for i in range(cm_norm.shape[0]):
    for j in range(cm_norm.shape[1]):
        plt.text(j, i, f"{cm_norm[i, j]:.2f}", ha="center", va="center",
                 color="white" if cm_norm[i, j] > 0.5 else "black")

plt.tight_layout()
plt.show()


# ======================
# ROC / AUC computation
# ======================
y_bin = label_binarize(y_true, classes=[0,1,2])  # One-hot encoding

auc_macro = roc_auc_score(y_bin, y_prob, average="macro", multi_class="ovr")
print("\nMacro AUC:", auc_macro)

# ======================
# Draw the ROC curves
# ======================
plt.figure(figsize=(6, 5))

# One one-vs-rest curve per class, labelled with that class's AUC.
for i, name in enumerate(["orig", "tts", "tts_gsm"]):
    fpr, tpr, _ = roc_curve(y_bin[:, i], y_prob[:, i])
    auc_c = roc_auc_score(y_bin[:, i], y_prob[:, i])
    plt.plot(fpr, tpr, label=f"{name} (AUC={auc_c:.3f})")

# Chance-level diagonal for reference
plt.plot([0,1],[0,1],'--', color='gray')
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC Curves (One-vs-Rest)")
plt.legend()
plt.tight_layout()
plt.show()


# ======================
# Show the overall accuracy as an image
# ======================
plt.figure(figsize=(5,3))
plt.text(0.5, 0.5, f"Overall Accuracy: {accuracy:.4f}\nMacro AUC: {auc_macro:.4f}",
         ha="center", va="center", fontsize=17)
plt.axis("off")
plt.show()


[INFO] Loading model from /content/drive/MyDrive/tts_outputs/wav2vec2_clf_model2


preprocessor_config.json:   0%|          | 0.00/159 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/163 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/291 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/85.0 [00:00<?, ?B/s]

[TEST] orig: total=5406, sampled=2703


In [None]:
# ================================================================
#   Wav2Vec2-base Baseline vs Fine-Tuned 3-class Model Performance
#   Test Dataset = 전체 데이터 중 10% 자동 분리
#   단일 셀 전체 코드
# ================================================================

import os
import random
import numpy as np
import librosa
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import classification_report, roc_auc_score
from sklearn.preprocessing import label_binarize
from sklearn.model_selection import train_test_split
from transformers import Wav2Vec2Processor, Wav2Vec2Model, Wav2Vec2ForSequenceClassification

# ---------------------------
# Configuration
# ---------------------------

# Root of the full dataset (must contain orig/, tts/ and tts_gsm/ sub-folders).
DATA_ROOT = "/content/drive/MyDrive/test"
# Directory holding the fine-tuned classifier checkpoint.
FINETUNED_MODEL_DIR = "/content/drive/MyDrive/tts_outputs/wav2vec2_clf_model2"

# Class names; a name's position in this list is its integer label.
labels = ["orig", "tts", "tts_gsm"]
label_map = {name: idx for idx, name in enumerate(labels)}

# Use the GPU when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")

# ---------------------------
# Collect every .wav file under DATA_ROOT/<class>/ with its integer label
# ---------------------------
paths = []
y = []

for cls in labels:
    folder = os.path.join(DATA_ROOT, cls)
    # os.listdir() returns entries in arbitrary order. Sort them so the
    # seeded train/test split further down selects the same files on every
    # run and on every filesystem.
    for fname in sorted(os.listdir(folder)):
        # Case-insensitive extension match (.wav / .WAV).
        if fname.lower().endswith(".wav"):
            paths.append(os.path.join(folder, fname))
            y.append(label_map[cls])

# Arrays so that train_test_split can index and stratify them.
paths = np.array(paths)
y = np.array(y)

# ---------------------------
# Hold out 10% of the data as the test set (stratified by class, fixed seed)
# ---------------------------
# NOTE(review): if the fine-tuned model was trained on files that live under
# DATA_ROOT, this split does not guarantee a disjoint test set — confirm
# against the training script before trusting the reported scores.
split = train_test_split(paths, y, test_size=0.1, stratify=y, random_state=42)
train_paths, test_paths, train_labels, test_labels = split

# The training part is only split off; nothing in this cell trains on it.
n_total = len(paths)
n_test = len(test_paths)
n_train = len(train_paths)
print(f"전체 데이터: {n_total}개")
print(f"테스트 데이터: {n_test}개 (10%)")
print(f"학습 데이터: {n_train}개 (사용 안 함, 단지 분리만 수행)")


# ================================================================
# Baseline 모델 정의: Wav2Vec2-base + 랜덤 Linear Head
# ================================================================

class W2V2_Baseline(nn.Module):
    """Baseline classifier: pretrained Wav2Vec2 encoder + untrained linear head.

    The encoder output is mean-pooled over the time axis and fed to a
    randomly initialised ``nn.Linear`` layer.

    Args:
        model_name: Checkpoint name or path passed to
            ``Wav2Vec2Model.from_pretrained``.
        num_labels: Number of output classes of the linear head.
    """

    def __init__(self, model_name: str = "facebook/wav2vec2-base", num_labels: int = 3):
        super().__init__()
        self.backbone = Wav2Vec2Model.from_pretrained(model_name)
        hidden = self.backbone.config.hidden_size
        self.classifier = nn.Linear(hidden, num_labels)  # random (untrained) head

    def forward(self, input_values, attention_mask=None):
        """Return the classification logits for a batch of waveforms.

        Args:
            input_values: Model input as produced by the Wav2Vec2 processor.
            attention_mask: Optional sample-level padding mask. When given,
                padded frames are excluded from the mean pooling.

        Returns:
            The output of the linear head applied to the pooled features.
        """
        out = self.backbone(input_values=input_values, attention_mask=attention_mask)
        hidden_states = out.last_hidden_state      # [B, T, H]

        if attention_mask is None:
            pooled = hidden_states.mean(dim=1)     # [B, H]
        else:
            # The mask is given per audio sample, but pooling happens per
            # encoder frame. Convert it with the backbone helper (the same
            # one Wav2Vec2ForSequenceClassification uses — assumes the
            # installed transformers version exposes it; TODO confirm).
            frame_mask = self.backbone._get_feature_vector_attention_mask(
                hidden_states.shape[1], attention_mask
            )
            weights = frame_mask.unsqueeze(-1).to(hidden_states.dtype)  # [B, T, 1]
            summed = (hidden_states * weights).sum(dim=1)
            # Guard against a fully-masked row to avoid dividing by zero.
            counts = weights.sum(dim=1).clamp(min=1.0)
            pooled = summed / counts               # [B, H]

        logits = self.classifier(pooled)
        return logits


# The baseline's classification head is randomly initialised, so its metrics
# depend on the RNG state. Seed it (same value as the split's random_state)
# so the baseline numbers can be reproduced from run to run.
torch.manual_seed(42)
baseline_model = W2V2_Baseline().to(device)
baseline_model.eval()

# Fine-tuned classifier loaded from the saved checkpoint directory.
finetuned_model = Wav2Vec2ForSequenceClassification.from_pretrained(
    FINETUNED_MODEL_DIR
).to(device)
finetuned_model.eval()


# ================================================================
# 평가 함수
# ================================================================
def evaluate(model, model_name):
    """Run `model` over the held-out test files and print its metrics.

    Reads the module-level ``test_paths``, ``test_labels``, ``processor``,
    ``device`` and ``labels``. Each file is loaded at 16 kHz and scored one
    at a time.

    Args:
        model: Either a module returning raw logits or one whose output has a
            ``.logits`` attribute.
        model_name: Display name used in the printed header.

    Returns:
        Tuple ``(auc_macro, auc_micro)`` of one-vs-rest ROC AUC scores.
    """
    # Do not rely on the caller: make sure train-only layers such as dropout
    # are disabled, otherwise the scores are noisy.
    model.eval()

    class_ids = list(range(len(labels)))

    y_true = []
    y_pred = []
    y_prob = []

    print(f"\n\n===== {model_name} 평가 시작 =====")

    for path, label in zip(test_paths, test_labels):

        # Only the waveform is needed; librosa resamples to the requested rate.
        wav, _ = librosa.load(path, sr=16000)

        inputs = processor(
            wav,
            sampling_rate=16000,
            return_tensors="pt"
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():

            outputs = model(**inputs)

            # The fine-tuned model returns an object carrying `.logits`; the
            # hand-written baseline returns the logits tensor directly.
            logits = getattr(outputs, "logits", outputs)

            # Single-item batch: take row 0 of the probability matrix.
            probs = F.softmax(logits, dim=-1).cpu().numpy()[0]
            pred = np.argmax(probs)

        y_true.append(label)
        y_pred.append(pred)
        y_prob.append(probs)

    y_prob = np.array(y_prob)
    # Indicator matrix with one column per class, as needed for per-class AUC.
    y_true_bin = label_binarize(y_true, classes=class_ids)

    auc_macro = roc_auc_score(y_true_bin, y_prob, average="macro")
    auc_micro = roc_auc_score(y_true_bin, y_prob, average="micro")

    print("AUC-macro:", auc_macro)
    print("AUC-micro:", auc_micro)
    print("\nClassification Report:")
    print(classification_report(y_true, y_pred, target_names=labels))

    return auc_macro, auc_micro



# ================================================================
# Run both evaluations
# ================================================================
baseline_auc_macro, baseline_auc_micro = evaluate(
    baseline_model,
    "Baseline (Wav2Vec2-base + Random Head)",
)
fine_auc_macro, fine_auc_micro = evaluate(
    finetuned_model,
    "Fine-Tuned Model",
)


# ================================================================
# Final comparison summary
# ================================================================
print("\n\n===== 최종 성능 비교 요약 =====")
print(
    f"Baseline AUC-macro: {baseline_auc_macro}",
    f"Fine-Tuned AUC-macro: {fine_auc_macro}",
    sep="\n",
)

# Relative gain of the fine-tuned macro AUC over the baseline, in percent.
auc_gain = fine_auc_macro - baseline_auc_macro
improvement = auc_gain / baseline_auc_macro * 100
print(f"\nAUC-macro 성능 향상률: {improvement:.2f}%")




전체 데이터: 39587개
테스트 데이터: 3959개 (10%)
학습 데이터: 35628개 (사용 안 함, 단지 분리만 수행)


  return datetime.utcnow().replace(tzinfo=utc)




===== Baseline (Wav2Vec2-base + Random Head) 평가 시작 =====
AUC-macro: 0.6491372652856797
AUC-micro: 0.5242667947625067

Classification Report:
              precision    recall  f1-score   support

        orig       0.07      0.01      0.02       541
         tts       0.00      0.00      0.00      1857
     tts_gsm       0.39      0.98      0.56      1561

    accuracy                           0.39      3959
   macro avg       0.16      0.33      0.19      3959
weighted avg       0.16      0.39      0.22      3959



===== Fine-Tuned Model 평가 시작 =====


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  return datetime.utcnow().replace(tzinfo=utc)


AUC-macro: 0.9999990322308209
AUC-micro: 0.9999993938883917

Classification Report:
              precision    recall  f1-score   support

        orig       1.00      1.00      1.00       541
         tts       1.00      1.00      1.00      1857
     tts_gsm       1.00      1.00      1.00      1561

    accuracy                           1.00      3959
   macro avg       1.00      1.00      1.00      3959
weighted avg       1.00      1.00      1.00      3959



===== 최종 성능 비교 요약 =====
Baseline AUC-macro: 0.6491372652856797
Fine-Tuned AUC-macro: 0.9999990322308209

AUC-macro 성능 향상률: 54.05%


  return datetime.utcnow().replace(tzinfo=utc)
