# AST (Audio Spectrogram Transformer) on ESC-50  
このノートブックは、論文 **AST: Audio Spectrogram Transformer** (Gong et al., 2021) を **ESC-50** で試すための “ノンストップ” 実験ノートです。  

やること：
1. ESC-50 を読み込んで train/val/test に分割（既存 `src` を再利用）
2. `MIT/ast-finetuned-audioset-10-10-0.4593` をベースに fine-tune（分類 head を 50 クラスに差し替え）
3. test 推論 → `inference_summary.csv` 出力 → accuracy / macro-F1 / weighted-F1
4. **入力特徴（ASTFeatureExtractor の mel-fbank を mel bin ごとに時間方向の平均・標準偏差で要約したもの）** と **エンコード後特徴（pooler_output）** を UMAP で比較し、  
   「fine-tune 済みエンコーダを通すことでクラス分離が良くなったか」を目で確認

注意：
- このノートは **editable install (`pip install -e .`)** 済みを想定しています。未実施なら最初に実行してください。  
- torchaudio の `torchcodec` 依存でハマらないよう、wav 読み込みは **soundfile** を使います（HF docs でもOKと明記）。  


In [1]:
# If you haven't done it yet (from the project root):
# pip install -e .

# `os` is not used in this cell; a later cell checks os.name when choosing
# the number of DataLoader workers.
import sys, os
print("python:", sys.version)


python: 3.12.3 (tags/v3.12.3:f6650f9, Apr  9 2024, 14:05:25) [MSC v.1938 64 bit (AMD64)]


## 1) 依存関係（必要ならインストール）

In [None]:
# 必要に応じて。すでに入っているならスキップでOK。
# !pip install -U transformers umap-learn soundfile scipy scikit-learn pyyaml


## 2) プロジェクトルート探索 + config 読み込み

In [2]:
from __future__ import annotations

from pathlib import Path
import yaml
import numpy as np
import pandas as pd
import torch

def find_project_root(start: Path | None = None) -> Path:
    start = (start or Path.cwd()).resolve()
    for p in [start] + list(start.parents):
        if (p / "config").exists() and (p / "src").exists():
            return p
    raise FileNotFoundError("Project root not found. Open this notebook somewhere under the project folder.")

# Locate the project root starting from the current working directory.
ROOT = find_project_root()
print("ROOT =", ROOT)

# Parse the default YAML config. Later cells read the "paths", "split",
# "project", "dataloader" and "train" sections from `cfg`.
CFG_PATH = ROOT / "config" / "default.yaml"
cfg = yaml.safe_load(CFG_PATH.read_text(encoding="utf-8"))
print("Loaded config:", CFG_PATH)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)


ROOT = C:\Users\hirok\code\02_sound\esc50_cnn
Loaded config: C:\Users\hirok\code\02_sound\esc50_cnn\config\default.yaml
device: cuda


## 3) ESC-50 metadata / split（既存 `src` を再利用）

In [3]:
from src.data.metadata import load_esc50_metadata
from src.data.split import SplitConfig, make_splits

# Resolve the metadata CSV and the audio directory from the config, relative to ROOT.
meta_csv = (ROOT / Path(cfg["paths"]["meta_csv"])).resolve()
audio_dir = (ROOT / Path(cfg["paths"]["audio_dir"])).resolve()

df = load_esc50_metadata(meta_csv)
print("meta rows:", len(df), "num_classes:", df["target"].nunique())
# Sanity check: the model setup cell hard-codes 50 classes (num_labels, id2label).
assert df["target"].nunique() == 50

# Build the split configuration; optional keys fall back to the defaults given here.
s_cfg = cfg["split"]
split_cfg = SplitConfig(
    method=s_cfg["method"],
    train_folds=list(s_cfg.get("train_folds", [])),
    val_folds=list(s_cfg.get("val_folds", [])),
    test_folds=list(s_cfg.get("test_folds", [])),
    kfold_num_folds=int(s_cfg.get("kfold", {}).get("num_folds", 5)),
    kfold_test_fold=int(s_cfg.get("kfold", {}).get("test_fold", 1)),
    kfold_val_fold=int(s_cfg.get("kfold", {}).get("val_fold", 2)),
    random_seed=int(cfg["project"]["seed"]),
)
# Later cells index `splits` with "train" / "val" / "test".
splits = make_splits(df, split_cfg)
print({k: len(v) for k, v in splits.items()})


meta rows: 2000 num_classes: 50
{'train': 1200, 'val': 400, 'test': 400}


## 4) AST のセットアップ（FeatureExtractor / Model）

In [4]:
import time, math
import soundfile as sf
from scipy.signal import resample_poly

from transformers import ASTFeatureExtractor, ASTForAudioClassification

# Checkpoint id used for both the feature extractor and the fine-tuning base model.
PRETRAINED = "MIT/ast-finetuned-audioset-10-10-0.4593"

feature_extractor = ASTFeatureExtractor.from_pretrained(PRETRAINED)

# label mapping (ESC-50): one category name per integer target id 0..49
target_to_cat = df.sort_values("target").drop_duplicates("target").set_index("target")["category"].to_dict()
id2label = {int(t): str(target_to_cat[int(t)]) for t in range(50)}
label2id = {v: k for k, v in id2label.items()}

model = ASTForAudioClassification.from_pretrained(
    PRETRAINED,
    num_labels=50,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,  # the head goes from AudioSet (527) to ESC-50 (50) classes, so its shapes do not match the checkpoint
)

model.to(device)
# Print the loaded configuration so mismatches with the data are visible early.
print("loaded:", PRETRAINED)
print("model num_labels:", model.config.num_labels)
print("feature_extractor sampling_rate:", feature_extractor.sampling_rate)
print("feature_extractor max_length:", feature_extractor.max_length, "num_mel_bins:", feature_extractor.num_mel_bins)


  from .autonotebook import tqdm as notebook_tqdm
Some weights of ASTForAudioClassification were not initialized from the model checkpoint at MIT/ast-finetuned-audioset-10-10-0.4593 and are newly initialized because the shapes did not match:
- classifier.dense.bias: found shape torch.Size([527]) in the checkpoint and torch.Size([50]) in the model instantiated
- classifier.dense.weight: found shape torch.Size([527, 768]) in the checkpoint and torch.Size([50, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


loaded: MIT/ast-finetuned-audioset-10-10-0.4593
model num_labels: 50
feature_extractor sampling_rate: 16000
feature_extractor max_length: 1024 num_mel_bins: 128


## 5) Dataset / DataLoader（soundfile + resample + ASTFeatureExtractor）

In [8]:
from torch.utils.data import Dataset, DataLoader

def load_wav_mono_resampled(path: Path, target_sr: int) -> np.ndarray:
    """Read an audio file as a mono float32 waveform at ``target_sr``.

    Two-dimensional reads are averaged over axis 1 (stereo -> mono). When the
    file's sampling rate differs from ``target_sr`` the signal is resampled
    with ``resample_poly`` using the up/down ratio reduced by their GCD.

    Args:
        path: Audio file to read with soundfile.
        target_sr: Desired sampling rate in Hz.

    Returns:
        The waveform as a float32 array.
    """
    samples, native_sr = sf.read(str(path), always_2d=False)
    if samples.ndim == 2:
        samples = samples.mean(axis=1)  # stereo -> mono
    samples = samples.astype(np.float32)

    if native_sr == target_sr:
        return samples

    # Reduce the rate ratio so the polyphase filter uses the smallest factors.
    divisor = math.gcd(native_sr, target_sr)
    resampled = resample_poly(samples, up=target_sr // divisor, down=native_sr // divisor)
    return resampled.astype(np.float32)

class Esc50AstDataset(Dataset):
    """Dataset yielding raw ESC-50 waveforms together with their labels.

    Feature extraction is not performed here. Each item carries the mono
    waveform resampled to the feature extractor's sampling rate, the integer
    target, and the file name.
    """

    def __init__(self, df_sub: pd.DataFrame, audio_dir: Path, fe: ASTFeatureExtractor):
        # Re-index from 0 so positional lookups in __getitem__ match __len__.
        self.df = df_sub.reset_index(drop=True)
        self.audio_dir = audio_dir
        self.fe = fe
        self.sr = int(fe.sampling_rate)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        filename = record["filename"]
        label = int(record["target"])
        waveform = load_wav_mono_resampled(self.audio_dir / filename, self.sr)
        return {"audio": waveform, "label": label, "filename": filename}

def collate_ast(batch, fe: ASTFeatureExtractor):
    """Turn a list of dataset items into model inputs, labels and file names.

    Args:
        batch: Items with "audio", "label" and "filename" keys.
        fe: Feature extractor applied to the whole list of waveforms, with
            padding to max_length and truncation enabled.

    Returns:
        ``(input_values, labels, filenames)`` where ``labels`` is a long
        tensor and ``filenames`` is a plain list.
    """
    audios, targets, fns = [], [], []
    for item in batch:
        audios.append(item["audio"])
        targets.append(item["label"])
        fns.append(item["filename"])
    labels = torch.tensor(targets, dtype=torch.long)

    extractor_kwargs = {
        "sampling_rate": fe.sampling_rate,
        "return_tensors": "pt",
        "padding": "max_length",
        "truncation": True,
    }
    feats = fe(audios, **extractor_kwargs)
    # input_values: [B, max_length, num_mel_bins]
    input_values = feats["input_values"]
    return input_values, labels, fns

# DataLoader settings come from the config, with defaults when a key is absent.
dl_cfg = cfg["dataloader"]
batch_size = int(dl_cfg.get("batch_size", 32))
num_workers = int(dl_cfg.get("num_workers", 0))
# Recommended default on Windows notebooks to avoid multiprocessing/pickling issues.
if os.name == "nt":
    num_workers = 0

# One dataset per split; all share the same audio directory and feature extractor.
ds_train = Esc50AstDataset(splits["train"], audio_dir, feature_extractor)
ds_val   = Esc50AstDataset(splits["val"], audio_dir, feature_extractor)
ds_test  = Esc50AstDataset(splits["test"], audio_dir, feature_extractor)

def collate_ast_fixed(batch):
    """Collate ``batch`` with the notebook-level ``feature_extractor``.

    NOTE (Windows): DataLoader with num_workers>0 requires picklable
    functions, so a named function is used instead of a lambda to prevent
    a PicklingError.
    """
    return collate_ast(batch, fe=feature_extractor)

# Only the training loader shuffles; validation and test keep the DataFrame order.
train_loader = DataLoader(ds_train, batch_size=batch_size, shuffle=True,  num_workers=num_workers,
                          collate_fn=collate_ast_fixed)
val_loader   = DataLoader(ds_val,   batch_size=batch_size, shuffle=False, num_workers=num_workers,
                          collate_fn=collate_ast_fixed)
test_loader  = DataLoader(ds_test,  batch_size=batch_size, shuffle=False, num_workers=num_workers,
                          collate_fn=collate_ast_fixed)

# Smoke test: pull one batch to confirm tensor shapes before training.
x0, y0, f0 = next(iter(train_loader))
print("batch input:", x0.shape, "labels:", y0.shape, "example fn:", f0[0])


batch input: torch.Size([32, 1024, 128]) labels: torch.Size([32]) example fn: 3-100018-A-18.wav


## 6) 学習ループ（fine-tune → best checkpoint 保存）

In [None]:
from sklearn.metrics import f1_score, accuracy_score
from tqdm.auto import tqdm

def run_eval(model, loader, device):
    """Evaluate ``model`` on ``loader`` and return loss and classification metrics.

    The model is switched to eval mode and gradients are disabled. The loss
    is averaged per sample (each batch loss is weighted by its batch size).

    Args:
        model: Model called as ``model(input_values=..., labels=...)`` whose
            output exposes ``loss`` and ``logits``.
        loader: Iterable yielding ``(inputs, labels, filenames)`` batches.
        device: Device the batches are moved to.

    Returns:
        Dict with "loss", "accuracy", "macro_f1" and "weighted_f1".
    """
    model.eval()
    true_chunks, pred_chunks = [], []
    loss_sum = 0.0
    seen = 0

    with torch.no_grad():
        for inputs, targets, _ in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            output = model(input_values=inputs, labels=targets)
            predicted = torch.argmax(output.logits, dim=1)

            batch_n = targets.size(0)
            loss_sum += float(output.loss.item()) * batch_n
            seen += batch_n
            true_chunks.append(targets.detach().cpu().numpy())
            pred_chunks.append(predicted.detach().cpu().numpy())

    y_true = np.concatenate(true_chunks)
    y_pred = np.concatenate(pred_chunks)

    mean_loss = loss_sum / max(seen, 1)  # guard against division by zero
    accuracy = float(accuracy_score(y_true, y_pred))
    macro_f1 = float(f1_score(y_true, y_pred, average="macro", zero_division=0))
    weighted_f1 = float(f1_score(y_true, y_pred, average="weighted", zero_division=0))
    return {
        "loss": mean_loss,
        "accuracy": accuracy,
        "macro_f1": macro_f1,
        "weighted_f1": weighted_f1,
    }

def train_ast(
    model,
    train_loader,
    val_loader,
    device,
    *,
    epochs: int = 10,
    lr: float = 1e-5,  # ASTは低め推奨（HF docsでも注意）
    weight_decay: float = 1e-4,
    grad_clip: float = 1.0,
    mixed_precision: bool = True,
    out_dir: Path | None = None,
):
    out_dir = out_dir or (ROOT / "reports" / f"ast_{time.strftime('%Y%m%d-%H%M%S')}")
    out_dir.mkdir(parents=True, exist_ok=True)
    print("out_dir:", out_dir)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scaler = torch.amp.GradScaler("cuda", enabled=(mixed_precision and device.type == "cuda"))

    best = {"macro_f1": -1.0, "epoch": -1}
    history = []

    for epoch in range(1, epochs + 1):
        model.train()
        pbar = tqdm(train_loader, desc=f"train epoch {epoch}", leave=False)
        running = 0.0
        n = 0

        for x, y, _ in pbar:
            x = x.to(device)
            y = y.to(device)

            optimizer.zero_grad(set_to_none=True)
            with torch.amp.autocast("cuda", enabled=(mixed_precision and device.type == "cuda")):
                out = model(input_values=x, labels=y)
                loss = out.loss

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()

            running += float(loss.item()) * y.size(0)
            n += y.size(0)
            pbar.set_postfix(loss=running / max(n, 1))

        val_stats = run_eval(model, val_loader, device)
        row = {"epoch": epoch, "train_loss": running / max(n, 1), **{f"val_{k}": v for k, v in val_stats.items()}}
        history.append(row)
        print(f"Epoch {epoch}: train_loss={row['train_loss']:.4f} | val_loss={val_stats['loss']:.4f} | val_macro_f1={val_stats['macro_f1']:.4f} | val_acc={val_stats['accuracy']:.4f}")

        if val_stats["macro_f1"] > best["macro_f1"]:
            best = {"macro_f1": val_stats["macro_f1"], "epoch": epoch}
            ckpt = out_dir / "best.pt"
            torch.save({"model_state_dict": model.state_dict(), "epoch": epoch, "val": val_stats}, ckpt)

    torch.save({"model_state_dict": model.state_dict(), "epoch": epochs}, out_dir / "last.pt")
    pd.DataFrame(history).to_csv(out_dir / "train_log.csv", index=False)
    print("best:", best)
    return out_dir

# ---- Train ----
# Epochs and learning rate come from the shared config, falling back to 10 and 1e-5.
# NOTE(review): the config lives in a project folder named esc50_cnn, so `train.lr`
# may have been chosen for a different model; confirm it suits AST fine-tuning.
EPOCHS = int(cfg["train"].get("epochs", 10))
LR = float(cfg["train"].get("lr", 1e-5))
out_dir = train_ast(model, train_loader, val_loader, device, epochs=EPOCHS, lr=LR)
out_dir


out_dir: C:\Users\hirok\code\02_sound\esc50_cnn\reports\ast_20260118-100246


                                                                         

Epoch 1: train_loss=2.7173 | val_loss=1.7254 | val_macro_f1=0.8365 | val_acc=0.8475


train epoch 2:  45%|████▍     | 17/38 [18:24<22:41, 64.83s/it, loss=1.21]

## 7) Test 推論 → inference_summary.csv → 指標算出

In [None]:
from tqdm.auto import tqdm

def load_checkpoint(model, ckpt_path: Path, device):
    """Load a checkpoint file into ``model`` and return the checkpoint dict.

    Args:
        model: Module whose ``load_state_dict`` receives the stored weights.
        ckpt_path: File previously written with ``torch.save`` containing a
            "model_state_dict" entry.
        device: Device the stored tensors are mapped to.

    Returns:
        The full checkpoint dict (weights plus any stored metadata).

    Raises:
        KeyError: If the checkpoint has no "model_state_dict" entry.
    """
    # weights_only=True restricts unpickling to tensors and plain Python
    # containers/scalars, so a tampered file cannot execute arbitrary code.
    # The checkpoints saved by train_ast contain only a state dict, an int
    # epoch and a dict of float metrics, which all load in this mode.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
    model.load_state_dict(ckpt["model_state_dict"])
    return ckpt

# Restore the best-validation weights written by train_ast before testing.
ckpt_best = out_dir / "best.pt"
_ = load_checkpoint(model, ckpt_best, device)
print("Loaded best:", ckpt_best)

model.eval()
rows = []
with torch.no_grad():
    for x, y, fns in tqdm(test_loader, desc="infer", leave=False):
        x = x.to(device)
        logits = model(input_values=x).logits
        # Class probabilities per clip; the prediction is the most probable class.
        prob = torch.softmax(logits, dim=1).detach().cpu().numpy()
        pred = prob.argmax(axis=1)

        y_np = y.numpy()
        for i, fn in enumerate(fns):
            row = {
                "filename": fn,
                "y_true": int(y_np[i]),
                "y_pred": int(pred[i]),
                "y_prob": float(prob[i, pred[i]]),  # probability of the predicted class
            }
            # One column per class (p_00, p_01, ...) with the full distribution.
            for c in range(prob.shape[1]):
                row[f"p_{c:02d}"] = float(prob[i, c])
            rows.append(row)

# One row per test clip, saved next to the checkpoints.
summary = pd.DataFrame(rows)
summary_path = out_dir / "inference_summary.csv"
summary.to_csv(summary_path, index=False)
print("Wrote:", summary_path)

# Metrics are computed from the written summary columns.
# accuracy_score / f1_score are imported in the training cell.
y_true = summary["y_true"].to_numpy()
y_pred = summary["y_pred"].to_numpy()
metrics = {
    "accuracy": float(accuracy_score(y_true, y_pred)),
    "macro_f1": float(f1_score(y_true, y_pred, average="macro", zero_division=0)),
    "weighted_f1": float(f1_score(y_true, y_pred, average="weighted", zero_division=0)),
}
print("Metrics:", metrics)


## 8) UMAP: 入力特徴 vs エンコード後特徴（pooler_output）

In [None]:
import umap
import matplotlib.pyplot as plt
import matplotlib as mpl
from tqdm.auto import tqdm

def batch_input_features(loader, max_items: int | None = None):
    """Summarize each clip's feature-extractor output as one fixed-length vector.

    For every batch of ``input_values`` (indexed as [B, T, M]) the mean and
    the standard deviation over axis 1 are concatenated, giving 2*M values
    per clip.

    Args:
        loader: Iterable yielding ``(input_values, labels, filenames)``.
        max_items: Stop once at least this many clips were collected. The
            check runs after a whole batch has been appended, so the result
            may overshoot by less than one batch. ``None`` or ``0`` disables
            the limit.

    Returns:
        ``(features, labels, filenames)`` as NumPy arrays.
    """
    feat_chunks, label_chunks, names = [], [], []
    collected = 0
    for inputs, targets, batch_names in tqdm(loader, desc="collect input features", leave=False):
        arr = inputs.numpy()  # [B, T, M]
        per_bin_mean = arr.mean(axis=1)  # [B, M]
        per_bin_std = arr.std(axis=1)  # [B, M]
        feat_chunks.append(np.concatenate([per_bin_mean, per_bin_std], axis=1))  # [B, 2M]
        label_chunks.append(targets.numpy())
        names.extend(batch_names)

        collected += arr.shape[0]
        if max_items and collected >= max_items:
            break

    features = np.concatenate(feat_chunks)
    labels = np.concatenate(label_chunks)
    return features, labels, np.array(names)

def batch_encoded_features(model, loader, device, max_items: int | None = None):
    """Collect the encoder's ``pooler_output`` for every clip in ``loader``.

    The model is put in eval mode and gradients are disabled. Only the
    ``model.ast`` encoder is called; the classification head is not used.

    Args:
        model: Model exposing an ``ast`` submodule whose output has a
            ``pooler_output`` attribute ([B, hidden]).
        loader: Iterable yielding ``(input_values, labels, filenames)``.
        device: Device the inputs are moved to.
        max_items: Stop once at least this many clips were collected. The
            check runs after a whole batch has been appended, so the result
            may overshoot by less than one batch. ``None`` or ``0`` disables
            the limit.

    Returns:
        ``(embeddings, labels, filenames)`` as NumPy arrays.
    """
    model.eval()
    emb_chunks, label_chunks, names = [], [], []
    collected = 0
    with torch.no_grad():
        for inputs, targets, batch_names in tqdm(loader, desc="collect encoded features", leave=False):
            encoded = model.ast(input_values=inputs.to(device))
            pooled = encoded.pooler_output.detach().cpu().numpy()
            emb_chunks.append(pooled)
            label_chunks.append(targets.numpy())
            names.extend(batch_names)

            collected += pooled.shape[0]
            if max_items and collected >= max_items:
                break

    embeddings = np.concatenate(emb_chunks)
    labels = np.concatenate(label_chunks)
    return embeddings, labels, np.array(names)

def make_discrete_cmap(num_classes: int, seed: int = 0):
    """Build a discrete colormap and norm for integer class ids ``0..num_classes-1``.

    Colors come from the qualitative ``tab20``, ``tab20b`` and ``tab20c``
    palettes chained together (60 colors). Using ``tab20`` alone would give
    several of the 50 ESC-50 classes exactly the same color. For
    ``num_classes <= 20`` the result is the same as drawing from ``tab20``
    only. Beyond 60 classes the palette repeats.

    Args:
        num_classes: Number of discrete classes.
        seed: Seed for the permutation that shuffles the color order, so
            neighboring ids do not get neighboring shades.

    Returns:
        ``(cmap, norm)``: a ``ListedColormap`` with one entry per class and a
        ``BoundaryNorm`` whose bins are centered on the integer ids.
    """
    palette = [
        color
        for name in ("tab20", "tab20b", "tab20c")
        for color in plt.get_cmap(name).colors
    ]
    colors = [palette[i % len(palette)] for i in range(num_classes)]

    rng = np.random.default_rng(seed)
    colors = [colors[i] for i in rng.permutation(num_classes)]

    cmap = mpl.colors.ListedColormap(colors)
    # Boundaries at -0.5, 0.5, ... place each integer id in the middle of its bin.
    norm = mpl.colors.BoundaryNorm(np.arange(-0.5, num_classes + 0.5, 1), cmap.N)
    return cmap, norm

def plot_umap(Z2, y, title, df_meta):
    """Scatter-plot a 2-D embedding colored by integer target id.

    Args:
        Z2: Array indexed as ``Z2[:, 0]`` / ``Z2[:, 1]`` for the two plot axes.
        y: Per-point target ids used as the color values.
        title: Figure title.
        df_meta: DataFrame with a ``target`` column; its number of unique
            values sets the number of color bins and colorbar ticks.

    The seed for the color permutation is read from the global ``cfg``.
    """
    num_classes = int(df_meta["target"].nunique())
    cmap, norm = make_discrete_cmap(num_classes, seed=int(cfg["project"]["seed"]))
    plt.figure(figsize=(12, 10))
    sc = plt.scatter(Z2[:, 0], Z2[:, 1], c=y, s=8, alpha=0.55, cmap=cmap, norm=norm, linewidths=0)
    # One colorbar tick per class id; small labels so all ticks fit.
    cb = plt.colorbar(sc, ticks=np.arange(num_classes))
    cb.set_label("target id")
    cb.ax.tick_params(labelsize=6)
    plt.title(title)
    plt.xlabel("UMAP-1")
    plt.ylabel("UMAP-2")
    plt.tight_layout()
    plt.show()

# Concatenate all three splits so the UMAP covers every clip. shuffle=False
# keeps the row order identical across the two passes below.
df_all = pd.concat([splits["train"], splits["val"], splits["test"]], axis=0).reset_index(drop=True)
ds_all = Esc50AstDataset(df_all, audio_dir, feature_extractor)
all_loader = DataLoader(ds_all, batch_size=batch_size, shuffle=False, num_workers=num_workers,
                        collate_fn=collate_ast_fixed)

# Two separate passes over the same loader, so every file is read and
# featurized twice. Labels and file names are kept from the first pass only.
# NOTE(review): collate_ast pads to max_length, so the input-feature mean/std
# presumably include padded frames as well; confirm this is acceptable.
Xin, y_all, fn_all = batch_input_features(all_loader)
Xenc, _, _         = batch_encoded_features(model, all_loader, device)

# Independent 2-D embeddings with identical hyperparameters and seed.
um_in  = umap.UMAP(n_neighbors=20, min_dist=0.15, metric="euclidean", random_state=int(cfg["project"]["seed"]))
Z_in   = um_in.fit_transform(Xin)

um_enc = umap.UMAP(n_neighbors=20, min_dist=0.15, metric="euclidean", random_state=int(cfg["project"]["seed"]))
Z_enc  = um_enc.fit_transform(Xenc)

plot_umap(Z_in,  y_all, "UMAP: AST input features (mean+std over time per mel)", df)
plot_umap(Z_enc, y_all, "UMAP: AST encoded features (pooler_output)", df)


## 9) 指定ラベルだけ可視化（ラベル名リストで選択）

In [None]:
def plot_umap_selected(Z2, y, df_meta, labels=None, emphasize=True, title="UMAP selected"):
    """Plot a 2-D embedding showing only the selected categories in color.

    Args:
        Z2: Array indexed as ``Z2[:, 0]`` / ``Z2[:, 1]`` for the two plot axes.
        y: Per-point integer target ids, aligned with ``Z2``.
        df_meta: DataFrame with ``target`` and ``category`` columns used to
            translate between ids and category names.
        labels: Category names to show. ``None`` shows every class. A single
            string is treated as a one-element list, and duplicates are
            dropped while keeping the first occurrence's position.
        emphasize: When True and ``labels`` is given, all points are first
            drawn in light gray as context.
        title: Title prefix; the selected labels are appended to it.

    Raises:
        ValueError: If ``labels`` is empty or names an unknown category.
    """
    t2c = df_meta.sort_values("target").drop_duplicates("target").set_index("target")["category"].to_dict()
    c2t = {c: int(t) for t, c in t2c.items()}
    num_classes = int(df_meta["target"].nunique())

    if labels is None:
        sel_targets = np.arange(num_classes, dtype=int)
        labels_txt = "ALL"
    else:
        # A bare string would otherwise be iterated character by character.
        if isinstance(labels, str):
            labels = [labels]
        # Drop repeats so each category gets exactly one color and one tick.
        labels = list(dict.fromkeys(labels))
        if not labels:
            raise ValueError("labels must contain at least one category")
        missing = [c for c in labels if c not in c2t]
        if missing:
            raise ValueError(f"Unknown category: {missing}")
        sel_targets = np.array([c2t[c] for c in labels], dtype=int)
        labels_txt = ", ".join(labels)

    sel_mask = np.isin(y, sel_targets)
    K = len(sel_targets)

    # Reuse the shared helper so both plots build colors the same way. Its
    # BoundaryNorm pins color index i to selection index i. Without an
    # explicit norm the colors are scaled to the observed min/max of the
    # data, which shifts them when a selected class has no points and puts
    # the colorbar ticks off-centre.
    cmap, norm = make_discrete_cmap(K, seed=int(cfg["project"]["seed"]))

    # Map each selected target id to its position in the selection (0..K-1).
    target_to_index = {int(t): i for i, t in enumerate(sel_targets.tolist())}
    c_sel = np.array([target_to_index[int(t)] for t in y[sel_mask]], dtype=int)

    plt.figure(figsize=(12, 10))
    if emphasize and labels is not None:
        plt.scatter(Z2[:, 0], Z2[:, 1], c="lightgray", s=8, alpha=0.12, linewidths=0)

    sc = plt.scatter(
        Z2[sel_mask, 0], Z2[sel_mask, 1],
        c=c_sel, s=10, alpha=0.65, cmap=cmap, norm=norm, linewidths=0,
    )

    cb = plt.colorbar(sc, ticks=np.arange(K))
    cb.ax.set_yticklabels([t2c[int(t)] for t in sel_targets])
    cb.set_label("category")
    cb.ax.tick_params(labelsize=8)

    plt.title(f"{title} | labels={labels_txt}")
    plt.xlabel("UMAP-1")
    plt.ylabel("UMAP-2")
    plt.tight_layout()
    plt.show()

# Example:
# plot_umap_selected(Z_in,  y_all, df, labels=["dog","rain","chainsaw"], title="UMAP input features")
# plot_umap_selected(Z_enc, y_all, df, labels=["dog","rain","chainsaw"], title="UMAP encoded features")
