
# Qwen2‑Audio Sparse Autoencoder (SAE) — End‑to‑End Analysis Notebook

> **Goal**: Reconstruct the full pipeline you were running to analyze Qwen2‑Audio representations using OpenSAE‑style sparse autoencoders, and evaluate interpretability (monosemanticity, selectivity), alignment with phoneme/word labels (via MFA TextGrids), seed consistency, capacity sweeps, and intervention experiments.

**Highlights**
- Extract frame‑level **audio** representations from `Qwen2AudioForConditionalGeneration`'s `multi_modal_projector` output.
- (Optional) Extract **text** token representations (last‑layer hidden states) from the language model for cross‑modal comparisons.
- Train **OpenSAE‑style** sparse autoencoders on audio (and later text) representations.
- Evaluate: reconstruction quality, sparsity, UMAP/PCA, monosemanticity, phoneme/word correlations, seed consistency, capacity vs. loss.
- **Interventions**: mask/boost single SAE units, pass through the generation stack, observe output differences.
- Utilities for **MFA TextGrid** parsing & alignment.

> Tip: Run this notebook on your machine or HPC environment with GPU. Adjust `DATA_ROOT`, `AUDIO_GLOB`, and `TRANS_GLOB` below to point to your dataset.


## 0. Environment & Paths

In [None]:

# If running locally, uncomment to install needed packages.
# !pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# !pip install -U transformers accelerate datasets einops umap-learn matplotlib pandas numpy scipy tqdm sacrebleu textgrid librosa scikit-learn networkx
# !pip install -U rich loguru

import os, sys, math, json, random, time, pathlib, glob, shutil
from pathlib import Path
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio

import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = (6,4)
plt.rcParams['figure.dpi'] = 140

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)

# ---- Project paths (EDIT THESE) ----
PROJECT_ROOT = Path.cwd()
DATA_ROOT    = PROJECT_ROOT / "data"          # Your audio + transcript + TextGrid live here
AUDIO_GLOB   = str(DATA_ROOT / "audio/**/*.wav")   # change as needed
TRANS_GLOB   = str(DATA_ROOT / "transcripts/**/*.txt")  # optional
TEXTGRID_GLOB= str(DATA_ROOT / "textgrid/**/*.TextGrid") # optional

OUT_DIR      = PROJECT_ROOT / "outputs"
CKPT_DIR     = OUT_DIR / "checkpoints"
FIG_DIR      = OUT_DIR / "figs"
LOG_DIR      = OUT_DIR / "logs"
for d in [OUT_DIR, CKPT_DIR, FIG_DIR, LOG_DIR]:
    d.mkdir(parents=True, exist_ok=True)

SEED = 1337
torch.manual_seed(SEED); np.random.seed(SEED); random.seed(SEED)


## 1. Load Qwen2‑Audio & Processor

In [None]:

from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration

QWEN_AUDIO_MODEL = os.environ.get("QWEN_AUDIO_MODEL", "Qwen/Qwen2-Audio-7B-Instruct")  # example id; change as needed

processor = AutoProcessor.from_pretrained(QWEN_AUDIO_MODEL, trust_remote_code=True)
model = Qwen2AudioForConditionalGeneration.from_pretrained(
    QWEN_AUDIO_MODEL,
    torch_dtype=torch.float16 if DEVICE=="cuda" else torch.float32,
    low_cpu_mem_usage=True,
    device_map="auto" if DEVICE=="cuda" else None,
    trust_remote_code=True
)

model.eval()
print("Loaded:", QWEN_AUDIO_MODEL)
print("Has multi_modal_projector:", hasattr(model, "multi_modal_projector"))


## 2. Audio Dataset & Representation Extraction

In [None]:

import librosa

SR = 16000
MAX_SEC = 20.0

def load_wav(path, sr=SR, max_sec=MAX_SEC):
    wav, in_sr = librosa.load(path, sr=None, mono=True)
    if in_sr != sr:
        wav = librosa.resample(wav, orig_sr=in_sr, target_sr=sr)
    if max_sec is not None:
        wav = wav[:int(max_sec*sr)]
    return torch.tensor(wav, dtype=torch.float32)

def list_audio():
    files = glob.glob(AUDIO_GLOB, recursive=True)
    files = sorted(files)
    return files

def extract_audio_reps(wav_tensor):
    """Return a dict of frame-level and pooled representations.

    The audio tower and projector are called directly. Assumption (based on the
    Hugging Face implementation; verify for your transformers version): the combined
    processor requires a text prompt, and the full model output has no
    `audio_hidden_states` field.
    """
    feats = processor.feature_extractor(
        wav_tensor.numpy(), sampling_rate=SR,
        return_tensors="pt", return_attention_mask=True,
    )
    tower = model.audio_tower
    ref = next(tower.parameters())  # match the tower's device and dtype (fp16 on GPU)
    input_features = feats["input_features"].to(device=ref.device, dtype=ref.dtype)

    with torch.no_grad():
        # NOTE: no attention mask is passed, so padded mel frames are attended to;
        # treat the result as an approximation of what the model sees in generation.
        audio_last = tower(input_features, return_dict=True).last_hidden_state  # [1, T, D_audio]
        if hasattr(model, "multi_modal_projector"):
            audio_proj = model.multi_modal_projector(audio_last)
        else:
            audio_proj = audio_last
    audio_proj = audio_proj.squeeze(0).float().cpu()  # [T_frames, D]

    # The feature extractor pads clips to a fixed length; keep only frames covering real audio.
    # Assumes a stride-2 convolution followed by stride-2 average pooling in the encoder.
    n_mel = int(feats["attention_mask"].sum().item())
    n_valid = ((n_mel - 1) // 2 + 1 - 2) // 2 + 1
    audio_proj = audio_proj[: max(1, min(n_valid, audio_proj.shape[0]))]
    pooled = audio_proj.mean(dim=0)
    return {"audio_proj_frames": audio_proj, "pooled": pooled, "extra": {}}

audio_files = list_audio()
print("Found audio files:", len(audio_files))
if audio_files:
    rep = extract_audio_reps(load_wav(audio_files[0]))
    print("Frame reps:", rep["audio_proj_frames"].shape, "Pooled:", rep["pooled"].shape)
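The frame count `T` follows from the encoder's downsampling. The sketch below works through that arithmetic under three assumptions: a 10 ms mel hop (160 samples at 16 kHz), a stride-2 convolution, and a stride-2 average pool in the audio encoder. These constants are assumptions about the Qwen2-Audio encoder, not values read from the loaded model, and `expected_frames` is a hypothetical helper. Compare its output with the frame count you actually observe; Whisper-style feature extractors pad clips to 30 s, so untrimmed outputs may show the 30 s count regardless of clip length.

```python
def expected_frames(duration_sec, sr=16000, mel_hop=160):
    """Estimate projector frames for a clip (hypothetical helper; constants are assumptions)."""
    n_mel = int(duration_sec * sr) // mel_hop      # mel frames at a 10 ms hop
    after_conv = (n_mel - 1) // 2 + 1              # stride-2 convolution
    after_pool = (after_conv - 2) // 2 + 1         # stride-2 average pooling
    return max(after_pool, 0)

for sec in (1.0, 20.0, 30.0):
    print(sec, expected_frames(sec))
```

Under these assumptions each frame covers about 40 ms of audio, which is the figure to use when mapping TextGrid times to frame indices.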


### (Optional) Text Token Embeddings for Cross‑Modal SAE

In [None]:

from transformers import AutoTokenizer

tokenizer = getattr(processor, "tokenizer", None)
if tokenizer is None:
    tokenizer = AutoTokenizer.from_pretrained(QWEN_AUDIO_MODEL, trust_remote_code=True)

def extract_text_reps(text):
    inputs = tokenizer(text, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        if hasattr(model, "language_model") and hasattr(model.language_model, "model"):
            out = model.language_model.model(**inputs, output_hidden_states=True, return_dict=True)
        else:
            out = model(**inputs, output_hidden_states=True, return_dict=True)
    last = out.hidden_states[-1].detach().float().cpu()   # [1, T, D]
    emb = last.squeeze(0)                                 # [T, D]
    pooled = emb.mean(dim=0)
    return {"text_frames": emb, "pooled": pooled}


## 3. Build a Representation Cache

In [None]:

CACHE_DIR = OUT_DIR / "rep_cache"
CACHE_DIR.mkdir(parents=True, exist_ok=True)

def build_audio_cache(limit=None, stride=1):
    files = list_audio()
    if limit is not None:
        files = files[:limit]
    rows = []
    for ix, path in enumerate(tqdm(files, desc="Extracting audio reps")):
        if ix % stride != 0:
            continue
        wav = load_wav(path)
        rep = extract_audio_reps(wav)
        # Index prefix keeps names unique when subfolders contain files with the same stem.
        out_pt = CACHE_DIR / f"{ix:06d}_{Path(path).stem}.pt"
        torch.save({"frames": rep["audio_proj_frames"], "pooled": rep["pooled"], "path": path}, out_pt)
        rows.append({"path": path, "pt": str(out_pt), "T": rep["audio_proj_frames"].shape[0]})
    meta = pd.DataFrame(rows)
    meta_path = CACHE_DIR / "audio_meta.csv"
    meta.to_csv(meta_path, index=False)
    print("Wrote:", meta_path, "n =", len(meta))
    return meta
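Because `AUDIO_GLOB` is recursive, cache file names need to stay unique across subfolders, and ideally stable when files are added or removed. One option is a name derived from the path relative to the data root. This is a minimal sketch; `cache_name` is a hypothetical helper, not part of the pipeline above, and the example paths are made up.

```python
import hashlib
from pathlib import Path

def cache_name(audio_path, data_root):
    """Build a cache file name from the path relative to data_root (hypothetical helper)."""
    rel = Path(audio_path).relative_to(data_root).with_suffix("")
    # A short hash of the relative path separates files that share a stem.
    digest = hashlib.sha1(rel.as_posix().encode("utf-8")).hexdigest()[:8]
    return f"{rel.name}_{digest}.pt"

a = cache_name("data/audio/spk1/utt001.wav", "data")
b = cache_name("data/audio/spk2/utt001.wav", "data")
print(a, b)
```

The same input path always yields the same name, so a partially built cache can be resumed without re-extracting existing files.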


## 4. OpenSAE‑Style Sparse Autoencoder

In [None]:

class OpenSAE(nn.Module):
    def __init__(self, d_in, d_hidden, topk=None, l1_coef=5e-4, pre_bias=True, unit_norm_decoder=True):
        super().__init__()
        self.d_in = d_in
        self.d_hidden = d_hidden
        self.topk = topk
        self.l1_coef = l1_coef
        self.pre_bias = nn.Parameter(torch.zeros(d_in)) if pre_bias else None

        self.E = nn.Parameter(torch.empty(d_hidden, d_in).normal_(0, 0.02))
        self.D = nn.Parameter(torch.empty(d_in, d_hidden).normal_(0, 0.02))
        self.unit_norm_decoder = unit_norm_decoder
        self.register_buffer("dead_mask", torch.ones(d_hidden, dtype=torch.bool))

    def encode(self, x):
        if self.pre_bias is not None:
            x = x + self.pre_bias
        h = F.linear(x, self.E)  # [B, H]
        h = F.relu(h)
        if self.topk is not None:
            k = min(self.topk, h.shape[-1])
            topk_vals, topk_idx = torch.topk(h, k=k, dim=-1)
            mask = torch.zeros_like(h)
            mask.scatter_(dim=-1, index=topk_idx, src=torch.ones_like(topk_vals))
            h = h * mask
        return h

    def decode(self, h):
        D = self.D
        if self.unit_norm_decoder:
            D = F.normalize(D, dim=0)
        x_hat = F.linear(h, D)  # D is [d_in, d_hidden]; F.linear computes h @ D.T -> [B, d_in]
        if self.pre_bias is not None:
            x_hat = x_hat - self.pre_bias
        return x_hat

    def forward(self, x):
        h = self.encode(x)
        x_hat = self.decode(h)
        return x_hat, h

    def loss(self, x, x_hat, h):
        rec = F.mse_loss(x_hat, x)
        l1 = h.abs().sum(dim=-1).mean()
        return rec + self.l1_coef * l1, {"rec": rec.detach(), "l1": l1.detach()}

    @torch.no_grad()
    def resample_dead_units(self, h_batch, thresh=1e-6):
        mean_act = (h_batch.abs() > thresh).float().mean(dim=0)  # [H]
        dead = mean_act < 1e-4
        num_dead = int(dead.sum().item())
        if num_dead > 0:
            self.E.data[dead] = torch.empty_like(self.E.data[dead]).normal_(0, 0.02)
            self.D.data[:, dead] = torch.empty_like(self.D.data[:, dead]).normal_(0, 0.02)
        self.dead_mask = ~dead
        return num_dead
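The `encode` step applies a ReLU and then keeps only the `topk` largest activations per input. The plain-Python sketch below shows that operation on a single vector so the effect is easy to check by hand. `relu_topk` is an illustrative helper, and its tie-breaking by index is an assumption that may differ from `torch.topk`.

```python
def relu_topk(h, k):
    """Apply ReLU, then keep only the k largest activations (ties broken by index)."""
    relu = [max(v, 0.0) for v in h]
    # Indices of the k largest post-ReLU values
    keep = set(sorted(range(len(relu)), key=lambda i: (-relu[i], i))[:k])
    return [v if i in keep else 0.0 for i, v in enumerate(relu)]

codes = relu_topk([0.5, -1.0, 2.0, 0.1, 1.5], k=2)
print(codes)  # [0.0, 0.0, 2.0, 0.0, 1.5]
```

With top-k active, at most `k` units are nonzero per frame, so the L1 term mainly shrinks the magnitudes of the surviving units and does little to select them.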


### 4.1 Training Utilities

In [None]:

def iter_rep_batches(meta_csv, batch_size=512, frame_sampling="uniform", frames_per_item=4, device=DEVICE):
    df = pd.read_csv(meta_csv)
    frames = []
    for _, row in df.iterrows():
        d = torch.load(row["pt"])
        X = d["frames"]  # [T, D]
        T = X.shape[0]
        if frame_sampling == "uniform":
            idx = torch.linspace(0, T-1, steps=min(frames_per_item, T)).long()
        else:
            idx = torch.randint(0, T, (min(frames_per_item, T),))
        frames.append(X[idx])  # [m, D]
        if sum(z.shape[0] for z in frames) >= batch_size:
            batch = torch.cat(frames, dim=0)[:batch_size]  # [B, D]
            yield batch.to(device)
            frames = []
    if frames:
        batch = torch.cat(frames, dim=0)
        yield batch.to(device)

def train_sae(meta_csv, d_in, d_hidden=8192, topk=64, l1=5e-4, lr=1e-3, steps=10000, log_every=100, save_every=1000, tag="audio_sae"):
    sae = OpenSAE(d_in=d_in, d_hidden=d_hidden, topk=topk, l1_coef=l1).to(DEVICE)
    opt = torch.optim.AdamW(sae.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=steps)

    log = []
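    # Random frame sampling so each pass over the cache sees different frames
    # (the default "uniform" mode picks the same frames from every file each epoch).
    def new_iter():
        return iter_rep_batches(meta_csv, batch_size=1024, frames_per_item=8, frame_sampling="random")
    giter = new_iter()
    for step in range(1, steps+1):
        try:
            x = next(giter)
        except StopIteration:
            giter = new_iter()
            x = next(giter)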

        x_hat, h = sae(x)
        loss, parts = sae.loss(x, x_hat, h)

        opt.zero_grad(set_to_none=True)
        loss.backward()
        nn.utils.clip_grad_norm_(sae.parameters(), max_norm=1.0)
        opt.step(); sched.step()

        if step % 250 == 0:
            with torch.no_grad():
                sae.resample_dead_units(h)

        if step % log_every == 0:
            msg = {
                "step": step, "loss": float(loss.item()),
                "rec": float(parts["rec"].item()), "l1": float(parts["l1"].item()),
                "lr": float(opt.param_groups[0]["lr"])
            }
            log.append(msg)
            print(msg)

        if step % save_every == 0 or step == steps:
            ckpt_path = CKPT_DIR / f"{tag}_step{step}.pt"
            torch.save({"sae": sae.state_dict(), "cfg": {
                "d_in": d_in, "d_hidden": d_hidden, "topk": topk, "l1": l1
            }}, ckpt_path)
            print("Saved:", ckpt_path)

    pd.DataFrame(log).to_csv(OUT_DIR / f"trainlog_{tag}.csv", index=False)
    return sae
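The `lr` values in the training log follow the cosine schedule. Its closed form, assuming the scheduler is stepped once per optimizer step and `eta_min` keeps its default of 0, is sketched below in plain Python. `cosine_lr` is an illustrative helper, not a PyTorch function.

```python
import math

def cosine_lr(step, total_steps, base_lr=1e-3, eta_min=0.0):
    """Closed-form cosine annealing: base_lr at step 0, eta_min at total_steps."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * step / total_steps))

for s in (0, 2500, 5000, 7500, 10000):
    print(s, f"{cosine_lr(s, 10000):.2e}")
```

Because the rate decays to zero at `steps`, a short run (for example the 3000-step sweep) ends at a very small learning rate and is not simply a truncated long run.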


## 5. Evaluation — Reconstruction & Sparsity

In [None]:

def eval_reconstruction(sae, meta_csv, n_batches=20):
    sae.eval()
    recs, l1s, act_frac = [], [], []
    with torch.no_grad():
        it = iter_rep_batches(meta_csv, batch_size=2048, frames_per_item=16)
        for i, x in zip(range(n_batches), it):
            x_hat, h = sae(x)
            recs.append(F.mse_loss(x_hat, x).item())
            l1s.append(h.abs().sum(dim=-1).mean().item())
            act_frac.append((h>0).float().mean().item())
    return {"rec_mse": float(np.mean(recs)), "l1_mean": float(np.mean(l1s)), "active_frac": float(np.mean(act_frac))}
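Raw MSE depends on the scale of the activations, which makes it hard to compare across layers or models. A scale-free companion metric is the fraction of variance unexplained (FVU): residual sum of squares divided by the total variance around the per-dimension mean. The sketch below computes it on nested lists; it is a suggested addition, not something `eval_reconstruction` computes.

```python
def fraction_variance_unexplained(x, x_hat):
    """FVU = sum((x - x_hat)^2) / sum((x - mean(x))^2), using a per-dimension batch mean."""
    n, d = len(x), len(x[0])
    mean = [sum(row[j] for row in x) / n for j in range(d)]
    resid = sum((x[i][j] - x_hat[i][j]) ** 2 for i in range(n) for j in range(d))
    total = sum((x[i][j] - mean[j]) ** 2 for i in range(n) for j in range(d))
    return resid / total

x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
perfect = [row[:] for row in x]
mean_only = [[3.0, 4.0] for _ in x]
print(fraction_variance_unexplained(x, perfect))    # 0.0: perfect reconstruction
print(fraction_variance_unexplained(x, mean_only))  # 1.0: no better than predicting the mean
```

An FVU of 0 means perfect reconstruction and 1 means the SAE does no better than outputting the mean vector.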


## 6. Visualization — PCA / UMAP of Sparse Codes

In [None]:

from sklearn.decomposition import PCA
import umap

def sample_codes(sae, meta_csv, max_samples=10000):
    codes = []
    with torch.no_grad():
        for x in iter_rep_batches(meta_csv, batch_size=2048, frames_per_item=16):
            _, h = sae(x)
            codes.append(h.detach().cpu())
            if sum(z.shape[0] for z in codes) >= max_samples:
                break
    H = torch.cat(codes, dim=0).numpy()
    return H

def plot_pca_umap(H, title="Sparse Codes"):
    pca2 = PCA(n_components=2).fit_transform(H)
    reducer = umap.UMAP(n_components=2, n_neighbors=30, min_dist=0.1, metric="cosine", random_state=SEED)
    um2 = reducer.fit_transform(H)

    fig, axs = plt.subplots(1, 2, figsize=(10,4))
    axs[0].scatter(pca2[:,0], pca2[:,1], s=2)
    axs[0].set_title("PCA (2D)")
    axs[1].scatter(um2[:,0], um2[:,1], s=2)
    axs[1].set_title("UMAP (2D)")
    fig.suptitle(title); plt.show()


## 7. MFA TextGrid Alignment — Phoneme/Word Mapping

In [None]:

from textgrid import TextGrid

def parse_textgrid(tg_path):
    tg = TextGrid.fromFile(tg_path)
    tiers = {t.name.lower(): t for t in tg.tiers}
    phone_tier = next((tiers[k] for k in tiers if "phone" in k), None)
    word_tier  = next((tiers[k] for k in tiers if "word" in k), None)
    phones = [(i.minTime, i.maxTime, i.mark) for i in phone_tier.intervals] if phone_tier else []
    words  = [(i.minTime, i.maxTime, i.mark) for i in word_tier.intervals] if word_tier else []
    return phones, words

def time_to_frame_idx(start, end, sr=SR, hop=640):
    # hop = audio samples per representation frame. Qwen2-Audio's encoder output is
    # roughly one frame per 40 ms (640 samples at 16 kHz); check this against
    # T / clip duration in your cache before trusting the alignment.
    s = int(round(start * sr / hop))
    e = int(round(end   * sr / hop))
    return s, e

def unit_selectivity_to_phonemes(sae, rep_pt, phones, hop=640):
    d = torch.load(rep_pt)
    X = d["frames"]             # [T, D]
    with torch.no_grad():
        H = sae.encode(X.to(DEVICE)).cpu().numpy()  # [T, K]
    results = []
    for (t0, t1, ph) in phones:
        s, e = time_to_frame_idx(t0, t1, hop=hop)
        s = max(s, 0); e = min(e, H.shape[0])
        if e <= s:
            continue
        h_seg = H[s:e]          # [len, K]
        avg = h_seg.mean(axis=0)  # [K]
        topk_idx = np.argsort(-avg)[:min(10, avg.shape[0])]  # strongest units first
        results.append({"phoneme": ph, "top_units": topk_idx.tolist(), "avg_top": avg[topk_idx].tolist()})
    return results
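For frame-level correlation analyses it is often convenient to give every frame a single phoneme label instead of slicing by interval. The sketch below labels each frame by the interval that contains its center time. `frame_labels` is a hypothetical helper, and `frame_sec` is an assumed frame duration (0.04 s here) that should equal `hop / sr` for whatever `hop` you use above.

```python
def frame_labels(intervals, n_frames, frame_sec=0.04, default=""):
    """Label each frame by the interval containing its center time."""
    labels = []
    for i in range(n_frames):
        center = (i + 0.5) * frame_sec
        label = default
        for t0, t1, mark in intervals:
            if t0 <= center < t1:
                label = mark
                break
        labels.append(label)
    return labels

# Toy intervals in the (start, end, mark) format returned by parse_textgrid
phones = [(0.0, 0.09, "sil"), (0.09, 0.25, "AH"), (0.25, 0.40, "T")]
labels = frame_labels(phones, n_frames=10)
print(labels)
```

Frames whose center falls outside every interval keep the `default` label and can be filtered out before computing statistics.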


## 8. Monosemanticity & Dictionary Inspection

In [None]:

def decoder_atoms(sae):
    D = sae.D.detach().cpu()
    if sae.unit_norm_decoder:
        D = F.normalize(D, dim=0)
    return D.numpy()  # [d_in, H]

def unit_top_inputs(sae, meta_csv, unit_idx, n=2000, topm=30):
    scores, frames = [], []
    with torch.no_grad():
        for x in iter_rep_batches(meta_csv, batch_size=2048, frames_per_item=16):
            h = sae.encode(x)
            s = h[:, unit_idx].detach().cpu()
            scores.append(s); frames.append(x.detach().cpu())
            if sum(len(z) for z in scores) >= n:
                break
    S = torch.cat(scores, dim=0)
    X = torch.cat(frames, dim=0)
    idx = torch.topk(S, k=min(topm, len(S))).indices
    return X[idx], S[idx]
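One simple proxy for monosemanticity is label purity: the share of a unit's total activation mass that falls on its most-activated label. The sketch below assumes you already have one unit's activations and a per-frame label list of the same length. It is one possible proxy, not an established metric of this notebook, and `label_purity` is a hypothetical helper.

```python
from collections import defaultdict

def label_purity(activations, labels):
    """Return (top_label, share of total activation mass falling on that label)."""
    mass = defaultdict(float)
    for a, lab in zip(activations, labels):
        mass[lab] += max(a, 0.0)
    total = sum(mass.values())
    if total == 0.0:
        return None, 0.0  # unit never fired on these frames
    top = max(mass, key=mass.get)
    return top, mass[top] / total

acts = [0.0, 3.0, 1.0, 0.0, 4.0, 2.0]
labs = ["sil", "AH", "AH", "T", "AH", "T"]
print(label_purity(acts, labs))  # ('AH', 0.8)
```

A purity near 1 suggests the unit responds to a single label; values near `1 / n_labels` suggest it is spread across labels. Label frequency is not corrected for, so frequent phonemes will score higher by default.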


## 9. Seed Consistency — Feature Matching Across Training Runs

In [None]:

import numpy as np
from scipy.optimize import linear_sum_assignment

def cosine_sim_matrix(A, B):
    A = torch.from_numpy(A).T
    B = torch.from_numpy(B).T
    A = F.normalize(A, dim=1); B = F.normalize(B, dim=1)
    return (A @ B.T).cpu().numpy()

def hungarian_match(sim):
    row_ind, col_ind = linear_sum_assignment(-sim)
    return row_ind, col_ind, float(sim[row_ind, col_ind].mean())
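To see what `hungarian_match` optimizes, the sketch below finds the same one-to-one assignment by brute force on a tiny similarity matrix: it tries every column permutation and keeps the one with the highest total similarity. This is only feasible for very small matrices and is meant as a sanity check. `best_assignment` is an illustrative helper, and the matrix values are made up.

```python
from itertools import permutations

def best_assignment(sim):
    """Exhaustively find the column permutation maximizing total similarity (small n only)."""
    n = len(sim)
    best = max(permutations(range(n)), key=lambda p: sum(sim[i][p[i]] for i in range(n)))
    mean_sim = sum(sim[i][best[i]] for i in range(n)) / n
    return list(best), mean_sim

sim = [
    [0.1, 0.9, 0.2],
    [0.8, 0.2, 0.3],
    [0.3, 0.1, 0.7],
]
cols, mean_sim = best_assignment(sim)
print(cols, round(mean_sim, 3))  # [1, 0, 2] 0.8
```

Here unit 0 of run A pairs with unit 1 of run B, and so on. The mean matched similarity is the seed-consistency score; values near 1 indicate that both runs learned nearly the same dictionary.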


## 10. Capacity Sweep — d_hidden vs. Reconstruction / Sparsity

In [None]:

def capacity_sweep(meta_csv, d_in, hidden_list=(2048, 4096, 8192, 16384), steps=3000, topk_ratio=0.01):
    results = []
    for H in hidden_list:
        topk = max(4, int(H * topk_ratio))
        sae = train_sae(meta_csv, d_in=d_in, d_hidden=H, topk=topk, steps=steps, tag=f"sweep_H{H}")
        stats = eval_reconstruction(sae, meta_csv, n_batches=20)
        stats.update({"H": H, "topk": topk})
        results.append(stats)
    df = pd.DataFrame(results)
    df.to_csv(OUT_DIR / "capacity_sweep.csv", index=False)
    return df

def plot_capacity_sweep(df):
    fig, ax = plt.subplots()
    ax.plot(df["H"], df["rec_mse"], marker="o")
    ax.set_xscale("log"); ax.set_xlabel("Hidden units (log)")
    ax.set_ylabel("Reconstruction MSE"); ax.set_title("Capacity vs Reconstruction")
    plt.show()
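Note that the sweep ties `topk` to the dictionary size through `topk_ratio`, so each point on the curve changes both the number of units and the number of active units per frame. The sketch below reproduces the `k` values the default settings would produce. `sweep_topk` is an illustrative helper that mirrors the rule in `capacity_sweep`.

```python
def sweep_topk(hidden_list=(2048, 4096, 8192, 16384), topk_ratio=0.01, floor=4):
    """Mirror the k = max(4, int(H * ratio)) rule used in the sweep."""
    return {H: max(floor, int(H * topk_ratio)) for H in hidden_list}

ks = sweep_topk()
print(ks)  # {2048: 20, 4096: 40, 8192: 81, 16384: 163}
```

If the goal is to isolate the effect of dictionary size, one option is to rerun the sweep with a fixed `topk` and compare the two curves.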


## 11. Intervention Experiments — Mask/Boost Units

In [None]:

def intervene_units(sae, x, mask_units=None, boost_units=None, boost_val=3.0):
    x = x.to(DEVICE)
    with torch.no_grad():
        # Use the SAE's own encoder so the baseline code matches training (including the top-k mask).
        h = sae.encode(x)
        if mask_units:
            h[:, mask_units] = 0.0
        if boost_units:
            h[:, boost_units] = h[:, boost_units] + boost_val
        x_hat = sae.decode(h)
    return x_hat.detach().cpu(), h.detach().cpu()
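Because the decoder is linear, boosting one unit by `boost_val` shifts the reconstruction by exactly `boost_val` times that unit's decoder atom, whatever the other activations are. The toy sketch below checks this in plain Python. The decoder values are made up, and normalization and the pre-bias are omitted because the bias cancels in the difference.

```python
def decode(h, D):
    """x_hat[i] = sum_j D[i][j] * h[j]; D has shape [d_in][d_hidden]."""
    return [sum(D[i][j] * h[j] for j in range(len(h))) for i in range(len(D))]

D = [[1.0, 0.0, 0.5],
     [0.0, 1.0, 0.5]]          # toy decoder: d_in=2, d_hidden=3
h = [2.0, 0.0, 1.0]
boost_unit, boost_val = 2, 3.0

h_boosted = list(h)
h_boosted[boost_unit] += boost_val

base = decode(h, D)
boosted = decode(h_boosted, D)
delta = [b - a for a, b in zip(base, boosted)]
print(base, boosted, delta)
```

The printed `delta` equals `boost_val` times column `boost_unit` of `D`. Any nonlinear effect of an intervention therefore comes from the downstream model, not from the SAE decoder.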


## 12. (Future) Text‑Modality SAE for Cross‑Modal Comparison

In [None]:

# Once you build a cache of text token embeddings, reuse OpenSAE to train a text-side SAE;
# then compare dictionary atoms and unit statistics between audio and text modalities.


## 13. Runbook — Minimal Steps to Recreate Your Pipeline


1. **Set paths** in the _Environment_ cell (`DATA_ROOT`, `AUDIO_GLOB`, `TEXTGRID_GLOB`).  
2. **Load model** (adjust `QWEN_AUDIO_MODEL` to your local checkpoint if needed).  
3. **Build cache**: `build_audio_cache(limit=..., stride=...)`.  
4. **Train SAE**: determine `d_in` from a cached `.pt`, then run `train_sae(...)`.  
5. **Evaluate**: `eval_reconstruction`, `sample_codes` + `plot_pca_umap`, and the monosemanticity utilities (`decoder_atoms`, `unit_top_inputs`).  
6. **Phoneme/word alignment**: parse TextGrid, call `unit_selectivity_to_phonemes`.  
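7. **Seed consistency**: train multiple SAEs with different seeds (call `torch.manual_seed` before each `train_sae` run), then match decoder atoms with `decoder_atoms`, `cosine_sim_matrix`, and `hungarian_match`.  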
8. **Capacity sweep**: run `capacity_sweep`, plot.  
9. **Interventions**: `intervene_units` then, if desired, integrate with the generation stack to inspect output diffs.  
10. **(Optional) Text SAE**: repeat with text embeddings and compare audio vs text features.
