In [None]:
from __future__ import annotations

import json
import sys
import time
from pathlib import Path

import torch
from torch.utils.data import DataLoader

from align_utils import (
    cls_align_beats,
    il_align_beats,
    select_acc_dtype,
    select_device,
    evaluate_accuracy_acil,
    init_w_fe
)
from data_utils import ESC50SplitDataset, pad_collate
from downloader import download_esc50
from model_utils import (
    BEATsWithHead,
    expand_classifier,
    load_beats_backbone,
    load_beats_model,
    maybe_resume_checkpoint,
)
from split_esc50 import ESC_ROOT, make_splits

In [2]:
# Fetch ESC-50 into ./Dataset via the project downloader, then build the splits.
download_esc50(Path("Dataset"))

esc_root = ESC_ROOT
if not esc_root.exists():
    raise FileNotFoundError(f"ESC-50 root not found: {esc_root}")

# Fold 1 is held out for testing; the seed is passed through to make_splits.
splits = make_splits(esc_root=esc_root, test_fold=1, seed=2026)

# Persist the splits as pretty-printed JSON next to the notebook.
out_path = Path("esc50_25_5x5_splits.json")
out_path.write_text(json.dumps(splits, indent=2), encoding="utf-8")
print(f"Saved splits to {out_path.resolve()}")

ESC-50 already exists at Dataset/ESC-50-master
Loaded 50 targets
Saved splits to /Users/yyy/Desktop/research/BEATS-acil/esc50_25_5x5_splits.json


In [3]:
# Load the pretrained BEATs checkpoint, build the train/test loaders and wrap
# the backbone with a 25-way classification head for fine-tuning.
beats_checkpoint_path = Path("checkpoints/BEATs_iter3_plus_AS2M.pt")
if not beats_checkpoint_path.exists():
    # Fail fast: the following cells use `model`, `device`, `train_loader` and
    # `val_loader` unconditionally, so silently skipping this setup would only
    # surface later as an unrelated NameError.
    raise FileNotFoundError(
        f"BEATs checkpoint not found: {beats_checkpoint_path}"
    )

device = select_device("mps")
beats = load_beats_model(beats_checkpoint_path, device).to(device)
print(f"BEATs model loaded from: {beats_checkpoint_path}")

train_dataset = ESC50SplitDataset(
    splits=splits,
    audio_dir=esc_root / "audio",
    use_split="train",
)
train_loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=0,
    collate_fn=pad_collate,
)

val_dataset = ESC50SplitDataset(
    splits=splits,
    audio_dir=esc_root / "audio",
    use_split="test",
)
val_loader = DataLoader(
    val_dataset,
    batch_size=4,
    shuffle=False,  # keep evaluation order stable
    num_workers=0,
    collate_fn=pad_collate,
)

# Original author's note (translated): the wrapper turns .wav input into fbank
# features for training using BEATs' own function and adds a classification
# head with num_classes=25. NOTE(review): not verifiable from this notebook;
# confirm against model_utils.BEATsWithHead.
model = BEATsWithHead(beats, num_classes=25).to(device)
print(model)
model.train()
print(f"Training mode: {model.training}, device: {device}")

  WeightNorm.apply(module, name, dim)


BEATs model loaded from: checkpoints/BEATs_iter3_plus_AS2M.pt
BEATsWithHead(
  (beats): BEATs(
    (post_extract_proj): Linear(in_features=512, out_features=768, bias=True)
    (patch_embedding): Conv2d(1, 512, kernel_size=(16, 16), stride=(16, 16), bias=False)
    (dropout_input): Dropout(p=0.1, inplace=False)
    (encoder): TransformerEncoder(
      (pos_conv): Sequential(
        (0): Conv1d(768, 768, kernel_size=(128,), stride=(1,), padding=(64,), groups=16)
        (1): SamePad()
        (2): GELU(approximate='none')
      )
      (layers): ModuleList(
        (0): TransformerSentenceEncoderLayer(
          (self_attn): MultiheadAttention(
            (dropout_module): Dropout(p=0.1, inplace=False)
            (relative_attention_bias): Embedding(320, 12)
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
    

In [4]:
# Inspect the raw pretrained checkpoint on CPU: container type, top-level keys,
# and whether it carries a "num_classes" entry.
orig_ckpt_path = Path("checkpoints/BEATs_iter3_plus_AS2M.pt")
origckpt = torch.load(orig_ckpt_path, map_location="cpu")
print(type(origckpt))
print(origckpt.keys())
print("num_classes:", origckpt.get("num_classes"))

<class 'dict'>
dict_keys(['cfg', 'model'])
num_classes: None


In [5]:
# Parameters of the original BEATs checkpoint: print name, shape and dtype of
# every entry in its "model" state dict.
state = origckpt["model"]
for k, v in state.items():
    print(k, tuple(v.shape), v.dtype)


post_extract_proj.weight (768, 512) torch.float32
post_extract_proj.bias (768,) torch.float32
patch_embedding.weight (512, 1, 16, 16) torch.float32
encoder.pos_conv.0.bias (768,) torch.float32
encoder.pos_conv.0.weight_g (1, 1, 128) torch.float32
encoder.pos_conv.0.weight_v (768, 48, 128) torch.float32
encoder.layers.0.self_attn.relative_attention_bias.weight (320, 12) torch.float32
encoder.layers.0.self_attn.k_proj.weight (768, 768) torch.float32
encoder.layers.0.self_attn.k_proj.bias (768,) torch.float32
encoder.layers.0.self_attn.v_proj.weight (768, 768) torch.float32
encoder.layers.0.self_attn.v_proj.bias (768,) torch.float32
encoder.layers.0.self_attn.q_proj.weight (768, 768) torch.float32
encoder.layers.0.self_attn.q_proj.bias (768,) torch.float32
encoder.layers.0.self_attn.out_proj.weight (768, 768) torch.float32
encoder.layers.0.self_attn.out_proj.bias (768,) torch.float32
encoder.layers.0.self_attn_layer_norm.weight (768,) torch.float32
encoder.layers.0.self_attn_layer_norm.bi

In [6]:
# Two parameter groups with separate learning rates: 1e-5 for `model.beats`
# and 1e-3 for `model.classifier`.
backbone_group = {"params": model.beats.parameters(), "lr": 1e-5}
head_group = {"params": model.classifier.parameters(), "lr": 1e-3}
optimizer = torch.optim.Adam([backbone_group, head_group])

# Cross-entropy between the head's logits and the integer class targets.
loss_fn = torch.nn.CrossEntropyLoss()

In [7]:
# Fine-tune the model for a fixed number of epochs, evaluating on the held-out
# fold after every epoch, then save both the full model and the bare backbone.
NUM_EPOCHS = 5  # single source of truth for the loop bound and the progress message

for epoch in range(1, NUM_EPOCHS + 1):
    epoch_start = time.perf_counter()
    total_loss = 0.0
    num_batches = 0
    for audio, padding_mask, targets in train_loader:
        audio = audio.to(device)
        padding_mask = padding_mask.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()  # clear gradients accumulated by the previous step
        logits = model(audio, padding_mask)
        loss = loss_fn(logits, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1
    print(num_batches)  # 200 in the recorded run

    # max(..., 1) guards against an empty loader.
    avg_loss = total_loss / max(num_batches, 1)

    # ---- evaluation on the validation loader ----
    model.eval()
    print(f"Eval mode: {model.training} (should be False)")
    correct = 0
    total = 0
    all_preds, all_targets = [], []
    with torch.no_grad():
        for audio, padding_mask, targets in val_loader:
            audio = audio.to(device)
            padding_mask = padding_mask.to(device)
            targets = targets.to(device)
            logits = model(audio, padding_mask)
            preds = logits.argmax(dim=1)
            correct += (preds == targets).sum().item()
            total += targets.numel()
            all_preds.append(preds)
            all_targets.append(targets)
        # Debug output: per-batch predictions and ground-truth labels.
        print(all_preds)
        print(all_targets)
    acc = correct / max(total, 1)
    # The reported time covers both the training pass and the evaluation pass.
    epoch_time = time.perf_counter() - epoch_start
    print(
        f"Epoch {epoch}/{NUM_EPOCHS} - loss: {avg_loss:.4f} - val_acc: {acc:.4f} "
        f"- time: {epoch_time:.1f}s"
    )
    # Switch back to training mode for the next epoch.
    model.train()
    print(f"Training mode: {model.training} (should be True)")

# Full fine-tuned model (backbone + head).
ckpt_out = Path("checkpoints/beats_finetuned_base25.pt")
ckpt_out.parent.mkdir(parents=True, exist_ok=True)
torch.save(
    {
        "cfg": model.beats.cfg.__dict__,
        "model": model.state_dict(),
        # Read the class count from the head itself so the stored metadata
        # cannot drift from the weights that are actually saved.
        "num_classes": int(model.classifier.weight.shape[0]),
    },
    ckpt_out,
)
print(f"Saved fine-tuned checkpoint to: {ckpt_out}")

# Backbone only, in the same {"cfg", "model"} layout as the original checkpoint.
backbone_out = Path("checkpoints/beats_base25_backbone.pt")
torch.save(
    {
        "cfg": model.beats.cfg.__dict__,
        "model": model.beats.state_dict(),
    },
    backbone_out,
)
print(f"Saved BEATs backbone checkpoint to: {backbone_out}")


  spectrum = torch.fft.rfft(strided_input).abs()


200
Eval mode: False (should be False)
[tensor([11, 11, 11, 11], device='mps:0'), tensor([11, 11, 11, 11], device='mps:0'), tensor([8, 8, 8, 8], device='mps:0'), tensor([8, 8, 8, 8], device='mps:0'), tensor([24, 24, 24, 24], device='mps:0'), tensor([24, 24, 24, 24], device='mps:0'), tensor([4, 4, 4, 4], device='mps:0'), tensor([4, 4, 4, 4], device='mps:0'), tensor([11, 16, 16, 16], device='mps:0'), tensor([16, 16, 16, 16], device='mps:0'), tensor([14, 14, 14, 14], device='mps:0'), tensor([14, 14, 14, 14], device='mps:0'), tensor([19, 19, 19, 19], device='mps:0'), tensor([19, 19, 19, 19], device='mps:0'), tensor([13, 13, 13, 13], device='mps:0'), tensor([13, 13, 13, 13], device='mps:0'), tensor([0, 0, 0, 0], device='mps:0'), tensor([0, 0, 0, 0], device='mps:0'), tensor([6, 6, 6, 6], device='mps:0'), tensor([ 6,  6, 18, 22], device='mps:0'), tensor([16, 16, 16, 16], device='mps:0'), tensor([16, 16,  8, 24], device='mps:0'), tensor([9, 9, 9, 9], device='mps:0'), tensor([9, 9, 9, 9], devic

In [7]:
# Reload the fine-tuned checkpoint onto the MPS device and list its contents:
# container type, top-level keys, stored class count, then every tensor's
# name, shape and dtype.
ckpt = torch.load(Path("checkpoints/beats_finetuned_base25.pt"), map_location="mps")
print(type(ckpt))
print(ckpt.keys())
print("num_classes:", ckpt.get("num_classes"))
# Note: this rebinds `state`, which an earlier cell used for the original
# (pre-fine-tuning) state dict.
state = ckpt["model"]
for k, v in state.items():
    print(k, tuple(v.shape), v.dtype)
    



<class 'dict'>
dict_keys(['cfg', 'model', 'num_classes'])
num_classes: 25
beats.post_extract_proj.weight (768, 512) torch.float32
beats.post_extract_proj.bias (768,) torch.float32
beats.patch_embedding.weight (512, 1, 16, 16) torch.float32
beats.encoder.pos_conv.0.bias (768,) torch.float32
beats.encoder.pos_conv.0.weight_g (1, 1, 128) torch.float32
beats.encoder.pos_conv.0.weight_v (768, 48, 128) torch.float32
beats.encoder.layers.0.self_attn.grep_a (1, 12, 1, 1) torch.float32
beats.encoder.layers.0.self_attn.relative_attention_bias.weight (320, 12) torch.float32
beats.encoder.layers.0.self_attn.k_proj.weight (768, 768) torch.float32
beats.encoder.layers.0.self_attn.k_proj.bias (768,) torch.float32
beats.encoder.layers.0.self_attn.v_proj.weight (768, 768) torch.float32
beats.encoder.layers.0.self_attn.v_proj.bias (768,) torch.float32
beats.encoder.layers.0.self_attn.q_proj.weight (768, 768) torch.float32
beats.encoder.layers.0.self_attn.q_proj.bias (768,) torch.float32
beats.encoder.la

In [8]:
# Fetch the classifier weight from the fine-tuned checkpoint; a missing key
# raises KeyError.
w = ckpt["model"]["classifier.weight"]
print(w.shape)
# Print only a small corner (first 2 rows, first 8 columns).
corner = w[:2, :8]
print(corner)


torch.Size([25, 768])
tensor([[ 0.0321,  0.0500, -0.0302,  0.0999,  0.0265,  0.0260, -0.0355,  0.0527],
        [-0.0303,  0.0814, -0.0121,  0.0439, -0.0198,  0.0510,  0.0005,  0.1116]],
       device='mps:0')


In [21]:
# Base phase: load the fine-tuned backbone, freeze it, build loaders restricted
# to the base classes, run cls_align_beats and evaluate on the base test set.
backbone_ckpt_path = Path("checkpoints/beats_base25_backbone.pt")
if not backbone_ckpt_path.exists():
    # Fail fast. The alignment statements at the end of this cell run
    # unconditionally, so skipping the setup would either raise NameError or
    # silently reuse a stale `model`/`device` left behind by an earlier cell.
    raise FileNotFoundError(
        f"BEATs backbone checkpoint not found: {backbone_ckpt_path}"
    )

device = select_device("mps")
beats = load_beats_backbone(backbone_ckpt_path, device)
beats.to(device)
beats.eval()
# Freeze the backbone parameters: no gradients are needed from here on.
for p in beats.parameters():
    p.requires_grad = False
print(f"BEATs backbone loaded from: {backbone_ckpt_path}")

base_classes = list(splits["base_classes"])[:25]
incremental_groups = splits["incremental_classes"][:5]

# Report how many classes have been seen after each incremental phase,
# using a running total instead of re-summing every prefix.
seen_count = len(base_classes)
for idx, group in enumerate(incremental_groups, start=1):
    seen_count += len(group)
    print(f"Seen classes after phase {idx}: {seen_count}")

model = BEATsWithHead(beats, num_classes=len(base_classes)).to(device)
model.eval()
print(f"Alignment mode: {model.training}, device: {device}")
# Called with an empty checkpoint path.
# NOTE(review): presumably a no-op in that case; confirm in model_utils.
maybe_resume_checkpoint(model, "", device)

# Map original class ids to contiguous indices 0..len(base_classes)-1.
base_class_to_idx = {cid: idx for idx, cid in enumerate(base_classes)}

base_train_dataset = ESC50SplitDataset(
    splits=splits,
    audio_dir=esc_root / "audio",
    use_split="train",
    class_ids=base_classes,
    class_to_idx=base_class_to_idx,
)
base_train_loader = DataLoader(
    base_train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=0,
    collate_fn=pad_collate,
)
base_val_dataset = ESC50SplitDataset(
    splits=splits,
    audio_dir=esc_root / "audio",
    use_split="test",
    class_ids=base_classes,
    class_to_idx=base_class_to_idx,
)
base_val_loader = DataLoader(
    base_val_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=0,
    collate_fn=pad_collate,
)

# Per-phase accuracy over all seen classes, and the accuracy drop on the base
# classes relative to the base phase.
acc_cil: list[float] = []
forget_rate: list[float] = []
print("num_classes:", len(base_classes))
print("classifier.weight(before):", model.classifier.weight.shape)

# Before incremental learning: initialise R and evaluate base-phase accuracy.
# NOTE(review): 768 and 4000 are presumably the backbone feature width and the
# expanded feature width (the recorded run reports R as 4000x4000); confirm
# against align_utils.init_w_fe.
W_fe = init_w_fe(768, 4000, device, dtype=select_acc_dtype(device))
R, W = cls_align_beats(
    train_loader=base_train_loader,
    model=model,
    device=device,
    num_classes=len(base_classes),
    rg=1e-3,
    W_fe=W_fe,
)
print("R shape:", R.shape)
print("classifier.weight(after):", model.classifier.weight.shape)
base_acc = evaluate_accuracy_acil(base_val_loader, model, device, W, W_fe)
acc_cil.append(base_acc)
print(f"Base phase acc: {base_acc:.4f}")

BEATs backbone loaded from: checkpoints/beats_base25_backbone.pt
Seen classes after phase 1: 30
Seen classes after phase 2: 35
Seen classes after phase 3: 40
Seen classes after phase 4: 45
Seen classes after phase 5: 50
Alignment mode: False, device: mps
num_classes: 25
classifier.weight(before): torch.Size([25, 768])
R shape: torch.Size([4000, 4000])
classifier.weight(after): torch.Size([25, 4000])
Base phase acc: 0.9650


In [16]:
# Incremental phases: for each new class group, grow the classifier and W,
# update (R, W) on the new classes only, then evaluate on all seen classes and
# on the base classes to measure forgetting.
# Note: this cell mutates `model`, `R`, `W`, `acc_cil` and `forget_rate`, so it
# is not idempotent; re-run the base-phase cell before re-running it.
for phase_idx, inc_classes in enumerate(incremental_groups, start=1):
    print(f"Phase {phase_idx} classes: {inc_classes}")

    # Defensive unpacking in case R still holds an (R, W) pair. This must
    # happen BEFORE W is expanded below, otherwise the freshly expanded W
    # would be overwritten by the old, narrower one.
    if isinstance(R, tuple):
        if len(R) != 2:
            raise ValueError(f"Unexpected R tuple length: {len(R)}")
        R, W = R

    new_num_classes = len(base_classes) + sum(
        len(g) for g in incremental_groups[:phase_idx]
    )
    print("new_num_classes:", new_num_classes)
    expand_classifier(model, new_num_classes)
    print(f"classifier.weight(Phase {phase_idx}):", model.classifier.weight.shape)

    # Expand W in lockstep with the classifier:
    # [fe_dim, old_C] -> [fe_dim, new_C], new columns initialised to zero.
    if W.size(1) < new_num_classes:
        W_new = torch.zeros(
            W.size(0), new_num_classes,
            device=W.device, dtype=W.dtype
        )
        W_new[:, :W.size(1)] = W
        W = W_new

    print("W shape after expand:", W.shape)  # e.g. [4000, 30] after phase 1

    # All classes seen so far, in order: base classes, then each group up to
    # and including this phase. Index in this list = label index.
    seen_classes = list(base_classes)
    for g in incremental_groups[:phase_idx]:
        seen_classes.extend(g)
    phase_class_to_idx = {cid: idx for idx, cid in enumerate(seen_classes)}
    print(f"seen_classes (phase {phase_idx}, n={len(seen_classes)}):")

    # Train only on the classes introduced in this phase, labelled with
    # indices from the full seen-class mapping.
    inc_train_dataset = ESC50SplitDataset(
        splits=splits,
        audio_dir=esc_root / "audio",
        use_split="train",
        class_ids=inc_classes,
        class_to_idx=phase_class_to_idx,
    )
    inc_train_loader = DataLoader(
        inc_train_dataset,
        batch_size=4,
        shuffle=True,
        num_workers=0,
        collate_fn=pad_collate,
    )

    R, W = il_align_beats(
        train_loader=inc_train_loader,
        model=model,
        device=device,
        num_classes=new_num_classes,
        R=R,
        W=W,
        repeat=1,
        W_fe=W_fe,
    )

    # Evaluate on the test split of every class seen so far.
    # Note: this rebinds `val_dataset`/`val_loader` from the fine-tuning cells.
    val_dataset = ESC50SplitDataset(
        splits=splits,
        audio_dir=esc_root / "audio",
        use_split="test",
        class_ids=seen_classes,
        class_to_idx=phase_class_to_idx,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=4,
        shuffle=False,
        num_workers=0,
        collate_fn=pad_collate,
    )

    print("val classes:", len(val_dataset.class_to_idx))
    acc = evaluate_accuracy_acil(val_loader, model, device, W, W_fe)
    acc_cil.append(acc)

    # Re-evaluate on the base-phase classes after the incremental update;
    # forgetting = base-phase accuracy minus current accuracy on that set.
    base_acc_now = evaluate_accuracy_acil(base_val_loader, model, device, W, W_fe)
    forget = acc_cil[0] - base_acc_now
    forget_rate.append(forget)

    print(
        f"Phase {phase_idx}/{len(incremental_groups)} "
        f"- acc: {acc:.4f} "
        f"- base_now: {base_acc_now:.4f} "
        f"- forget: {forget:.4f}"
    )

# Mean accuracy over the base phase and all incremental phases.
if acc_cil:
    avg = sum(acc_cil) / len(acc_cil)
    print(f"Average accuracy: {avg:.4f}")
else:
    print("Average accuracy: n/a (no phases evaluated)")

ckpt_out = Path("checkpoints/beats_base25_incremental.pt")
ckpt_out.parent.mkdir(parents=True, exist_ok=True)
torch.save(
    {
        "cfg": model.beats.cfg.__dict__,
        "model": model.state_dict(),
        "num_classes": model.classifier.out_features,
    },
    ckpt_out,
)
print(f"Saved incremental checkpoint to: {ckpt_out}")



Phase 1 classes: [21, 33, 43, 3, 45]
new_num_classes: 30
classifier.weight(Phase 1): torch.Size([30, 4000])
W shape after expand: torch.Size([4000, 30])
seen_classes (phase 1, n=30):
torch.Size([4000, 30])
torch.Size([4000, 4000])
val classes: 30
Phase 1/5 - acc: 0.9583 - base_now: 0.9550 - forget: 0.0100
Phase 2 classes: [9, 49, 5, 0, 15]
new_num_classes: 35
classifier.weight(Phase 2): torch.Size([35, 4000])
W shape after expand: torch.Size([4000, 35])
seen_classes (phase 2, n=35):
torch.Size([4000, 35])
torch.Size([4000, 4000])
val classes: 35
Phase 2/5 - acc: 0.9643 - base_now: 0.9550 - forget: 0.0100
Phase 3 classes: [28, 31, 40, 36, 26]
new_num_classes: 40
classifier.weight(Phase 3): torch.Size([40, 4000])
W shape after expand: torch.Size([4000, 40])
seen_classes (phase 3, n=40):
torch.Size([4000, 40])
torch.Size([4000, 4000])
val classes: 40
Phase 3/5 - acc: 0.9625 - base_now: 0.9550 - forget: 0.0100
Phase 4 classes: [35, 39, 38, 14, 6]
new_num_classes: 45
classifier.weight(Phase