In [1]:
import torch, torch.nn.functional as F
import pytorch_lightning as pl

class MusicCNN(pl.LightningModule):
    def __init__(self, lr=1e-3, dropout=.3, n_filters=32):
        super().__init__()
        self.save_hyperparameters()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, n_filters, 3, padding=1),  # 128×128
            torch.nn.BatchNorm2d(n_filters), torch.nn.ReLU(), torch.nn.MaxPool2d(2),
            torch.nn.Conv2d(n_filters, n_filters*2, 3, padding=1),  # 64×64
            torch.nn.BatchNorm2d(n_filters*2), torch.nn.ReLU(), torch.nn.MaxPool2d(2)
        )
        self.drop = torch.nn.Dropout(dropout)
        self.fc   = torch.nn.Linear((n_filters*2)*32*32, 10)        # 32×32 flat

    def forward(self, x):
        return self.fc(self.drop(torch.flatten(self.conv(x), 1)))

    def _step(self, batch):
        x, y = batch; logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc  = (logits.argmax(1) == y).float().mean()
        return loss, acc

    def training_step(self, batch, _):  l, a = self._step(batch); self.log_dict({"train_loss": l, "train_acc": a}); return l
    def validation_step(self, batch, _): l, a = self._step(batch); self.log_dict({"val_loss": l, "val_acc": a}, prog_bar=True)
    def configure_optimizers(self):      return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)


In [2]:
CKPT_PATH = "../notebooks/checkpoints/best_cnn.ckpt"
DEVICE    = "cuda" if torch.cuda.is_available() else "cpu"

model = MusicCNN.load_from_checkpoint(CKPT_PATH).to(DEVICE).eval()
print("Model loaded from:", CKPT_PATH)


Model loaded from: ../notebooks/checkpoints/best_cnn.ckpt


In [3]:
import torch, librosa, numpy as np, pathlib, random
from scipy.spatial.distance import cdist
import pandas as pd

N_MELS = 128
TIME_FRAMES = 128     # same as training

def preprocess_audio(path, sr=22_050):
    y, _ = librosa.load(path, sr=sr, mono=True)
    mel  = librosa.feature.melspectrogram(y, sr=sr, n_mels=N_MELS)
    mel  = librosa.power_to_db(mel, ref=np.max)

    # center‑crop / pad to 128 frames
    t = mel.shape[1]
    if t > TIME_FRAMES:
        start = (t - TIME_FRAMES) // 2
        mel = mel[:, start:start + TIME_FRAMES]
    else:
        mel = np.pad(mel, ((0, 0), (0, TIME_FRAMES - t)), "constant")

    return torch.tensor(mel).unsqueeze(0).unsqueeze(0).float()   # (1,1,128,128)


In [4]:
def get_embedding(x):
    with torch.no_grad():
        feats = model.conv(x)           # output shape (B,C,32,32)
        gap   = torch.nn.functional.adaptive_avg_pool2d(feats, 1).squeeze(-1).squeeze(-1)
    return gap.cpu().numpy()

model.get_embedding = get_embedding   # monkey‑patch for quick use


In [5]:
TEST_DIR = pathlib.Path("../datasets/GTZAN")       # adjust if needed
bank_paths = random.sample(
    list(TEST_DIR.rglob("*.au")), k=200)           # 200 tracks for quick demo

emb_bank, genres = [], []
for p in bank_paths:
    x = preprocess_audio(p).to(DEVICE)
    emb_bank.append(model.get_embedding(x))
    genres.append(p.parent.name)

emb_bank = np.vstack(emb_bank)
print("Built embedding bank with", len(bank_paths), "tracks")


  0.01657104] as keyword args. From version 0.10 passing these as positional arguments will result in an error
  mel  = librosa.feature.melspectrogram(y, sr=sr, n_mels=N_MELS)
 -0.17633057] as keyword args. From version 0.10 passing these as positional arguments will result in an error
  mel  = librosa.feature.melspectrogram(y, sr=sr, n_mels=N_MELS)
 -0.00448608] as keyword args. From version 0.10 passing these as positional arguments will result in an error
  mel  = librosa.feature.melspectrogram(y, sr=sr, n_mels=N_MELS)
  mel  = librosa.feature.melspectrogram(y, sr=sr, n_mels=N_MELS)
  mel  = librosa.feature.melspectrogram(y, sr=sr, n_mels=N_MELS)
 -0.07946777] as keyword args. From version 0.10 passing these as positional arguments will result in an error
  mel  = librosa.feature.melspectrogram(y, sr=sr, n_mels=N_MELS)
 -0.08255005] as keyword args. From version 0.10 passing these as positional arguments will result in an error
  mel  = librosa.feature.melspectrogram(y, sr=sr, n_mel

Built embedding bank with 200 tracks


  mel  = librosa.feature.melspectrogram(y, sr=sr, n_mels=N_MELS)
 -0.17047119] as keyword args. From version 0.10 passing these as positional arguments will result in an error
  mel  = librosa.feature.melspectrogram(y, sr=sr, n_mels=N_MELS)


In [6]:
def predict_and_recommend(wav_path, k=5):
    x = preprocess_audio(wav_path).to(DEVICE)
    with torch.no_grad():
        logits = model(x)
        pred   = logits.argmax(1).item()
        idx2genre = sorted({p.parent.name for p in bank_paths})
        genre  = idx2genre[pred]

    seed_emb = model.get_embedding(x)
    dists    = cdist(seed_emb, emb_bank, metric="cosine").flatten()
    idxs     = dists.argsort()[:k]

    recs = [(bank_paths[i].name, genres[i], dists[i]) for i in idxs]
    return genre, recs


In [8]:
sample_clip = "../datasets/GTZAN/disco/disco.00008.au"   # pick any file
genre, recs = predict_and_recommend(sample_clip, k=5)

print("Input clip:", pathlib.Path(sample_clip).name)
print("Predicted genre:", genre, "\n\nTop‑5 similar tracks:")
df = pd.DataFrame(recs, columns=["file", "genre", "cosine_dist"])

display(df)


Input clip: disco.00008.au
Predicted genre: disco 

Top‑5 similar tracks:


  mel  = librosa.feature.melspectrogram(y, sr=sr, n_mels=N_MELS)


Unnamed: 0,file,genre,cosine_dist
0,disco.00012.au,disco,0.002541
1,pop.00019.au,pop,0.002703
2,rock.00075.au,rock,0.00376
3,hiphop.00025.au,hiphop,0.004357
4,country.00041.au,country,0.004979
