In [2]:
# PRE-TRAINING ON MULTI-SCALE MODEL

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
!cp -r "/content/drive/MyDrive/IRMAS-TrainingData" /content/

In [5]:
!ls "/content/IRMAS-TrainingData"

cel  cla  flu  gac  gel  org  pia  README.txt  sax  tru  vio  voi


In [6]:
import os, time
import librosa
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [7]:
ROOT_DIR = "/content/IRMAS-TrainingData"
SR = 16000
N_MELS = 128
N_FFT = 1024
HOP_LENGTH = 320
TARGET_FRAMES = 151
BATCH_SIZE = 32
LR = 1e-4
EPOCHS = 15

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)

Device: cuda


In [8]:
# Mapping

In [9]:
instrument_to_id = {
    "cel": 0, "cla": 1, "flu": 2, "gac": 3, "gel": 4,
    "pia": 5, "sax": 6, "tru": 7, "vio": 8
}
NUM_CLASSES = len(instrument_to_id)

In [10]:
# IRMASDataset

In [11]:
class IRMASDataset(Dataset):
    def __init__(self, root_dir, instrument_to_id):
        self.files = []
        self.labels = []
        for inst, lab in instrument_to_id.items():
            inst_dir = os.path.join(root_dir, inst)
            if not os.path.isdir(inst_dir):
                raise FileNotFoundError(f"Missing folder: {inst_dir}")
            for f in os.listdir(inst_dir):
                if f.endswith(".wav"):
                    self.files.append(os.path.join(inst_dir, f))
                    self.labels.append(lab)

    def __len__(self):
        return len(self.files)

    def _fix_frames(self, mel, target_frames):
        # mel: [F, T]
        T = mel.shape[1]
        if T == target_frames:
            return mel
        if T > target_frames:
            return mel[:, :target_frames]
        pad = target_frames - T
        # pad sağa (min değerle)
        return F.pad(mel, (0, pad), value=float(mel.min()))

    def __getitem__(self, idx):
        y, _ = librosa.load(self.files[idx], sr=SR, mono=True)
        mel = librosa.feature.melspectrogram(
            y=y, sr=SR, n_mels=N_MELS, n_fft=N_FFT, hop_length=HOP_LENGTH
        )
        mel = librosa.power_to_db(mel)
        mel = torch.tensor(mel, dtype=torch.float32)       # [F,T]
        mel = self._fix_frames(mel, TARGET_FRAMES)         # [F,Tfixed]
        mel = mel.unsqueeze(0)                             # [1,F,T]
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return mel, label

In [12]:
# DataLoader + sanity check

In [13]:
ds = IRMASDataset(ROOT_DIR, instrument_to_id)
dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)

xb, yb = next(iter(dl))
print("Mel:", xb.shape, "Labels:", yb.shape)

Mel: torch.Size([32, 1, 128, 151]) Labels: torch.Size([32])


In [14]:
# Multi-Scale CRNN model

In [15]:
class MultiScaleConv(nn.Module):
    def __init__(self, in_ch=1, out_ch=64):
        super().__init__()
        b = out_ch // 3
        self.b1 = nn.Sequential(nn.Conv2d(in_ch, b, 3, padding=1), nn.ReLU())
        self.b2 = nn.Sequential(nn.Conv2d(in_ch, b, 5, padding=2), nn.ReLU())
        self.b3 = nn.Sequential(nn.Conv2d(in_ch, out_ch - 2*b, 7, padding=3), nn.ReLU())
        self.bn = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        x = torch.cat([self.b1(x), self.b2(x), self.b3(x)], dim=1)
        return self.bn(x)

class MultiScaleCRNN(nn.Module):
    def __init__(self, num_classes, rnn_hidden=128):
        super().__init__()
        self.conv = nn.Sequential(
            MultiScaleConv(1, 64),
            nn.MaxPool2d((2,2)),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d((2,2))
        )
        # conv çıktısını [B, 128, F', T'] yapar. F' üzerinden average alıp GRU’ya veririz.
        self.rnn = nn.GRU(
            input_size=128, hidden_size=rnn_hidden, num_layers=1,
            batch_first=True, bidirectional=True
        )
        self.classifier = nn.Linear(2*rnn_hidden, num_classes)

    def forward(self, x):
        x = self.conv(x)             # [B,128,F',T']
        x = torch.mean(x, dim=2)     # [B,128,T']
        x = x.transpose(1, 2)        # [B,T',128]
        y, _ = self.rnn(x)           # [B,T',2H]
        y = torch.mean(y, dim=1)     # [B,2H]
        return self.classifier(y)


In [16]:
model = MultiScaleCRNN(NUM_CLASSES).to(DEVICE)
crit = nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr=LR)

for ep in range(1, EPOCHS+1):
    model.train()
    total_loss, correct, n = 0.0, 0, 0
    t0 = time.time()

    for xb, yb in dl:
        xb = xb.to(DEVICE, non_blocking=True)
        yb = yb.to(DEVICE, non_blocking=True)

        opt.zero_grad()
        logits = model(xb)
        loss = crit(logits, yb)
        loss.backward()
        opt.step()

        total_loss += loss.item()
        pred = logits.argmax(dim=1)
        correct += (pred == yb).sum().item()
        n += yb.size(0)

    print(f"Epoch {ep:02d} | loss {total_loss/len(dl):.4f} | acc {correct/n:.3f} | {time.time()-t0:.1f}s")

save_path = "/content/drive/MyDrive/mscrnn_irmas.pth"
torch.save({"conv": model.conv.state_dict(), "rnn": model.rnn.state_dict()}, save_path)
print("Saved MS-CRNN backbone to:", save_path)


Epoch 01 | loss 2.0654 | acc 0.239 | 135.9s
Epoch 02 | loss 1.9846 | acc 0.282 | 132.9s
Epoch 03 | loss 1.9236 | acc 0.298 | 136.4s
Epoch 04 | loss 1.8436 | acc 0.340 | 138.4s
Epoch 05 | loss 1.7578 | acc 0.359 | 135.8s
Epoch 06 | loss 1.6894 | acc 0.389 | 138.7s
Epoch 07 | loss 1.6400 | acc 0.405 | 135.7s
Epoch 08 | loss 1.5931 | acc 0.422 | 135.5s
Epoch 09 | loss 1.5543 | acc 0.430 | 134.2s
Epoch 10 | loss 1.5157 | acc 0.451 | 138.5s
Epoch 11 | loss 1.5002 | acc 0.463 | 136.0s
Epoch 12 | loss 1.4496 | acc 0.474 | 135.4s
Epoch 13 | loss 1.4198 | acc 0.488 | 137.3s
Epoch 14 | loss 1.3959 | acc 0.491 | 134.6s
Epoch 15 | loss 1.3516 | acc 0.513 | 135.8s
Saved MS-CRNN backbone to: /content/drive/MyDrive/mscrnn_irmas.pth


In [17]:
!ls -lh "/content/drive/MyDrive/mscrnn_irmas.pth"

-rw------- 1 root root 1.1M Dec 22 20:56 /content/drive/MyDrive/mscrnn_irmas.pth


In [18]:
import torch
state = torch.load("/content/drive/MyDrive/mscrnn_irmas.pth", map_location="cpu")
print(len(state), list(state.keys())[:2])

2 ['conv', 'rnn']
