In [1]:
# PRE-TRAINING THE CBAM MODEL ON IRMAS

In [2]:
# Make Google Drive visible under /content/drive so the IRMAS folder and the
# checkpoint path in MyDrive can be reached from this Colab runtime.
import google.colab.drive as drive

drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
# Copy the IRMAS training folder from Drive to local disk, presumably so training
# does not read clips over the Drive mount. The copy is skipped when the local
# folder already exists, so re-running this cell is cheap; delete
# /content/IRMAS-TrainingData to force a fresh copy (e.g. after an interrupted one).
!test -d "/content/IRMAS-TrainingData" || cp -r "/content/drive/MyDrive/IRMAS-TrainingData" /content/

In [4]:
# List the local copy to confirm the expected instrument folders are present.
!ls "/content/IRMAS-TrainingData"

cel  cla  flu  gac  gel  org  pia  README.txt  sax  tru  vio  voi


In [5]:
import os, time
import librosa
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [6]:
# --- Paths ---------------------------------------------------------------------
ROOT_DIR = "/content/IRMAS-TrainingData"   # local copy of the IRMAS training folder

# --- Audio / feature extraction ------------------------------------------------
SR = 16000          # every clip is resampled to this rate on load
N_MELS = 128        # mel bands = spectrogram height
N_FFT = 1024        # STFT window length in samples
HOP_LENGTH = 320    # STFT hop in samples (20 ms at 16 kHz)
# Spectrogram width in frames; shorter spectrograms are right-padded and longer
# ones cropped to this size. 151 == 1 + (3 * 16000) // 320, i.e. presumably sized
# for 3-second clips with centered STFT framing — TODO confirm against clip length.
TARGET_FRAMES = 151

# --- Optimization ----------------------------------------------------------------
BATCH_SIZE = 32
LR = 1e-4
EPOCHS = 15
SEED = 42

# Seed torch's global RNG (CPU and CUDA) so weight initialisation and DataLoader
# shuffling repeat across Restart & Run All. cuDNN kernels may still introduce
# small run-to-run differences on GPU.
torch.manual_seed(SEED)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)

Device: cuda


In [7]:
# Instrument-to-class-id mapping

In [8]:
# Class ids for the instrument folders used in pre-training, assigned in listed
# order. The IRMAS folder also contains "org" and "voi"; they are not mapped here,
# so IRMASDataset never reads them.
instrument_to_id = {
    code: idx
    for idx, code in enumerate(("cel", "cla", "flu", "gac", "gel", "pia", "sax", "tru", "vio"))
}
NUM_CLASSES = len(instrument_to_id)


In [9]:
# IRMASDataset

In [10]:
class IRMASDataset(Dataset):
    """Single-label IRMAS clips served as fixed-size dB mel spectrograms.

    Each item is ``(mel, label)``: ``mel`` is a float32 tensor of shape
    ``[1, N_MELS, TARGET_FRAMES]`` and ``label`` is the scalar ``torch.long`` class
    id taken from ``instrument_to_id``. Feature settings are read at access time
    from the module-level constants ``SR``, ``N_MELS``, ``N_FFT``, ``HOP_LENGTH``
    and ``TARGET_FRAMES``.

    Args:
        root_dir: Directory containing one sub-folder per instrument code.
        instrument_to_id: Mapping from instrument folder name to class id. Only
            the folders listed here are scanned; any other folder is ignored.

    Raises:
        FileNotFoundError: If a folder named in ``instrument_to_id`` is missing.
    """

    def __init__(self, root_dir, instrument_to_id):
        self.files = []
        self.labels = []
        for inst, lab in instrument_to_id.items():
            inst_dir = os.path.join(root_dir, inst)
            if not os.path.isdir(inst_dir):
                raise FileNotFoundError(f"Missing folder: {inst_dir}")
            # os.listdir order depends on the filesystem; sort so the
            # index -> file mapping (and therefore seeded shuffling) is reproducible.
            for fname in sorted(os.listdir(inst_dir)):
                # Case-insensitive so "*.WAV" files are not silently skipped.
                if fname.lower().endswith(".wav"):
                    self.files.append(os.path.join(inst_dir, fname))
                    self.labels.append(lab)

    def __len__(self):
        return len(self.files)

    def _fix_frames(self, mel, target_frames):
        """Crop or right-pad a ``[n_mels, T]`` spectrogram to ``target_frames`` columns.

        Padding uses the spectrogram's own minimum value, so the padded tail sits
        at the clip's floor level (the input here is already dB-scaled) rather
        than at an arbitrary constant.
        """
        n_frames = mel.shape[1]
        if n_frames == target_frames:
            return mel
        if n_frames > target_frames:
            return mel[:, :target_frames]
        # Pad on the right (time axis) with the minimum value.
        return F.pad(mel, (0, target_frames - n_frames), value=float(mel.min()))

    def __getitem__(self, idx):
        """Load clip ``idx`` and return ``(mel, label)``.

        ``mel`` has shape ``[1, N_MELS, TARGET_FRAMES]`` (leading channel dim for
        Conv2d); ``label`` is a scalar ``torch.long``.
        """
        # NOTE: decoding, resampling and the mel transform run on every access,
        # so each epoch repeats this work; caching features would speed training.
        y, _ = librosa.load(self.files[idx], sr=SR, mono=True)
        mel = librosa.feature.melspectrogram(
            y=y, sr=SR, n_mels=N_MELS, n_fft=N_FFT, hop_length=HOP_LENGTH
        )
        mel = librosa.power_to_db(mel)
        mel = torch.tensor(mel, dtype=torch.float32)       # [n_mels, T]
        mel = self._fix_frames(mel, TARGET_FRAMES)         # [n_mels, TARGET_FRAMES]
        mel = mel.unsqueeze(0)                             # [1, n_mels, TARGET_FRAMES]
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return mel, label


In [11]:
# DataLoader + sanity check

In [12]:
ds = IRMASDataset(ROOT_DIR, instrument_to_id)
# Pinned host memory only helps when batches are copied to a CUDA device, so
# enable it only on GPU runtimes.
dl = DataLoader(
    ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=(DEVICE == "cuda"),
)

# Sanity check: pull one batch and confirm the shapes/dtypes the model expects.
xb, yb = next(iter(dl))
assert xb.shape[1:] == (1, N_MELS, TARGET_FRAMES), xb.shape
assert yb.dtype == torch.long, yb.dtype
print("Mel:", xb.shape, "Labels:", yb.shape)


Mel: torch.Size([32, 1, 128, 151]) Labels: torch.Size([32])


In [13]:
# CBAM model

In [14]:
class ChannelAttention(nn.Module):
    """Channel attention gate of CBAM.

    Average- and max-pooled per-channel descriptors are each passed through the
    same two-layer MLP (bottleneck width ``max(1, in_ch // r)``). Their sum goes
    through a sigmoid and rescales every channel of the input.

    Args:
        in_ch: Number of input channels C.
        r: Reduction ratio of the MLP bottleneck.
    """

    def __init__(self, in_ch, r=8):
        super().__init__()
        hidden = max(1, in_ch // r)  # keep at least one hidden unit for small C
        self.mlp = nn.Sequential(
            nn.Linear(in_ch, hidden),
            nn.ReLU(),
            nn.Linear(hidden, in_ch),
        )

    def forward(self, x):
        # x: [B, C, H, W]; both descriptors are [B, C]
        desc_avg = x.mean(dim=(2, 3))
        desc_max = x.amax(dim=(2, 3))
        weights = torch.sigmoid(self.mlp(desc_avg) + self.mlp(desc_max))
        return x * weights[:, :, None, None]

class SpatialAttention(nn.Module):
    """Spatial attention gate of CBAM.

    The input is pooled across channels (mean and max). The resulting 2-channel
    map goes through one ``k x k`` convolution, and the sigmoid of the result
    rescales every spatial position of the input, broadcast over channels.

    Args:
        k: Convolution kernel size; padding ``k // 2`` keeps H and W unchanged
            for odd ``k``.
    """

    def __init__(self, k=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size=k, padding=k // 2)

    def forward(self, x):
        pooled = torch.cat(
            (x.mean(dim=1, keepdim=True), x.amax(dim=1, keepdim=True)),
            dim=1,
        )
        gate = self.conv(pooled).sigmoid()
        return x * gate

class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel attention, then spatial attention.

    The input is reweighted first per channel (``ChannelAttention``) and then
    per spatial position (``SpatialAttention``). For odd ``k`` the output shape
    equals the input shape.

    Args:
        ch: Number of input channels.
        r: Reduction ratio of the channel-attention MLP (default 8).
        k: Kernel size of the spatial-attention convolution (default 7).
    """

    def __init__(self, ch, r=8, k=7):
        super().__init__()
        self.ca = ChannelAttention(ch, r=r)
        self.sa = SpatialAttention(k=k)

    def forward(self, x):
        return self.sa(self.ca(x))

class CBAM_CNN(nn.Module):
    """Small CNN with a CBAM block after each conv stage, plus a linear head.

    ``features`` runs two stages of Conv3x3 -> ReLU -> MaxPool(2) -> CBAM
    (1 -> 32 -> 64 channels), then global average pooling, giving a 64-dim
    embedding per input. ``classifier`` maps that embedding to ``num_classes``
    logits. The layer order inside ``features`` fixes its ``state_dict`` keys,
    which is what the training cell saves.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        layers = []
        for c_in, c_out in ((1, 32), (32, 64)):
            layers.extend([
                nn.Conv2d(c_in, c_out, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
                CBAM(c_out),
            ])
        layers.append(nn.AdaptiveAvgPool2d((1, 1)))
        self.features = nn.Sequential(*layers)
        self.classifier = nn.Linear(64, num_classes)

    def forward(self, x):
        embedding = torch.flatten(self.features(x), start_dim=1)
        return self.classifier(embedding)


In [15]:
# Training loop + saving

In [16]:
model = CBAM_CNN(NUM_CLASSES).to(DEVICE)
crit = nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr=LR)

for ep in range(1, EPOCHS + 1):
    model.train()
    # total_loss accumulates per-sample loss (batch mean * batch size), so the
    # epoch average is exact even when the final batch is smaller than BATCH_SIZE.
    total_loss, correct, n = 0.0, 0, 0
    t0 = time.time()

    for xb, yb in dl:
        xb = xb.to(DEVICE, non_blocking=True)
        yb = yb.to(DEVICE, non_blocking=True)

        opt.zero_grad(set_to_none=True)
        logits = model(xb)
        loss = crit(logits, yb)  # mean over the batch (default reduction)
        loss.backward()
        opt.step()

        batch_n = yb.size(0)
        total_loss += loss.item() * batch_n
        correct += (logits.argmax(dim=1) == yb).sum().item()
        n += batch_n

    # NOTE: these are training-set metrics on the data being fitted, not a
    # held-out estimate. max(n, 1) avoids ZeroDivisionError on an empty loader.
    denom = max(n, 1)
    print(f"Epoch {ep:02d} | loss {total_loss/denom:.4f} | acc {correct/denom:.3f} | {time.time()-t0:.1f}s")

# Only the feature extractor (conv + CBAM + pooling) is saved; the classifier head
# is dropped, presumably because a new head is trained downstream.
save_path = "/content/drive/MyDrive/cbam_irmas_features.pth"
torch.save(model.features.state_dict(), save_path)
print("Saved CBAM features to:", save_path)


Epoch 01 | loss 2.1438 | acc 0.202 | 130.4s
Epoch 02 | loss 2.0132 | acc 0.264 | 132.4s
Epoch 03 | loss 1.9320 | acc 0.296 | 133.0s
Epoch 04 | loss 1.8600 | acc 0.329 | 130.7s
Epoch 05 | loss 1.7984 | acc 0.350 | 134.3s
Epoch 06 | loss 1.7670 | acc 0.354 | 130.6s
Epoch 07 | loss 1.7159 | acc 0.376 | 132.0s
Epoch 08 | loss 1.6813 | acc 0.387 | 131.1s
Epoch 09 | loss 1.6435 | acc 0.404 | 131.7s
Epoch 10 | loss 1.6238 | acc 0.406 | 130.6s
Epoch 11 | loss 1.6061 | acc 0.417 | 134.7s
Epoch 12 | loss 1.5676 | acc 0.434 | 130.1s
Epoch 13 | loss 1.5405 | acc 0.443 | 134.1s
Epoch 14 | loss 1.5249 | acc 0.446 | 133.1s
Epoch 15 | loss 1.5106 | acc 0.453 | 129.8s
Saved CBAM features to: /content/drive/MyDrive/cbam_irmas_features.pth
