In [1]:
import os
import json
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from tqdm import tqdm
import torchaudio
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModel, AutoProcessor, ClapModel

In [62]:
# ============================================================
# 1. –ö–æ–Ω—Ñ–∏–≥
# ============================================================

@dataclass
class Config:
    train_json: str = "final_dataset/train.json"
    val_json: str   = "final_dataset/val.json"
    test_json: str  = "final_dataset/test.json"

    text_model_name: str = "BAAI/bge-m3"
    projection_dim: int = 2048
    dropout: float = 0.3
    biencoder_ckpt: str = "25_ep_tag_hard_negs_bge_8192/bi_encoder_best.pth"

    clap_ckpt_dir: str = "clap"  # dir —Å save_pretrained() –∏–∑ Jamendo‚Äë—Å–∫—Ä–∏–ø—Ç–∞

    audio_sr: int = 48000
    max_audio_seconds: int = 30

    max_desc_len: int = 4096
    max_lyrics_len: int = 4096

    batch_size: int = 8
    num_workers: int = 0
    epochs: int = 5
    lr: float = 1e-5
    weight_decay: float = 1e-2
    fused_dim: int = 256
    temperature: float = 0.07

    out_dir: str = "fusion_ckpts"
    val_log_path: str = "fusion_val_losses.json"


cfg = Config()
os.makedirs(cfg.out_dir, exist_ok=True)
AUDIO_ROOT = os.path.expanduser('~/persistent_volume/final_dataset/audio')  # –∏–ª–∏ –ø—É—Ç—å –≤ –∫–æ–ª–∞–±–µ/–∫–ª–∞—Å—Ç–µ—Ä–µ
PROJECT_DIR = os.path.expanduser('~/persistent_volume')
OUT_TRAIN_FEATS = os.path.join(PROJECT_DIR, 'final_dataset', 'train_features.pt')
OUT_VAL_FEATS = os.path.join(PROJECT_DIR, 'final_dataset', 'val_features.pt')
OUT_TEST_FEATS = os.path.join(PROJECT_DIR, 'final_dataset', 'test_features.pt')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Device: cuda


In [63]:
torch.cuda.empty_cache()

In [43]:
import os, json, torch
from tqdm import tqdm

from transformers import AutoTokenizer, AutoModel, AutoProcessor, ClapModel
import torchaudio
import torch.nn as nn
import torch.nn.functional as F

# === –ö–æ–Ω—Ñ–∏–≥ –∏ –ø—É—Ç–∏ ===

PROJECT_DIR = os.path.expanduser('~/persistent_volume')
DATA_DIR = os.path.join(PROJECT_DIR, 'final_dataset')
AUDIO_ROOT = os.path.join(DATA_DIR, 'audio')

TEXT_MODEL_NAME = "BAAI/bge-m3"
BIENCODER_CKPT = "25_ep_tag_hard_negs_bge_8192/bi_encoder_best.pth"
CLAP_CKPT_DIR = "clap"

BATCH_SIZE = 8
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device:", device)

# === BiEncoder ===

class BiEncoder(nn.Module):
    def __init__(self, backbone, projection_dim, p_drop):
        super().__init__()
        self.backbone = backbone
        emb_dim = self.backbone.config.hidden_size

        def head():
            return nn.Sequential(
                nn.Linear(emb_dim, projection_dim),
                nn.ReLU(inplace=True),
                nn.Dropout(p_drop),
                nn.Linear(projection_dim, projection_dim),
                nn.Dropout(p_drop),
            )

        self.desc_head = head()
        self.lyr_head = head()

    def _mean_pool(self, outputs, attention_mask):
        hs = outputs.last_hidden_state.float()
        mask = attention_mask.unsqueeze(-1).float()
        denom = mask.sum(1).clamp_min(1e-6)
        pooled = (hs * mask).sum(1) / denom
        return pooled

    def encode_description(self, ids, mask):
        out = self.backbone(input_ids=ids, attention_mask=mask)
        proj = self.desc_head(self._mean_pool(out, mask))
        return F.normalize(proj, p=2, dim=1)

    def encode_lyrics(self, ids, mask):
        out = self.backbone(input_ids=ids, attention_mask=mask)
        proj = self.lyr_head(self._mean_pool(out, mask))
        return F.normalize(proj, p=2, dim=1)

print("Loading BiEncoder & CLAP...")

text_tok = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
text_backbone = AutoModel.from_pretrained(TEXT_MODEL_NAME, trust_remote_code=True).to(device).eval()
biencoder = BiEncoder(text_backbone, projection_dim=2048, p_drop=0.3).to(device)
biencoder.load_state_dict(torch.load(BIENCODER_CKPT, map_location='cpu'))
biencoder.eval()

clap_processor = AutoProcessor.from_pretrained(CLAP_CKPT_DIR)
clap_model = ClapModel.from_pretrained(CLAP_CKPT_DIR).to(device).eval()
clap_sr = clap_processor.feature_extractor.sampling_rate

print("‚úì Models loaded")

@torch.no_grad()
def process_split(split_name: str):
    json_path = os.path.join(DATA_DIR, f"{split_name}.json")
    out_feats_path = os.path.join(DATA_DIR, f"{split_name}_features.pt")

    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    print(f"\n--- Processing split: {split_name} ---")
    print(f"Items in {json_path}: {len(data)}")

    all_feats = {}

    def process_batch(batch_items):
        nonlocal all_feats

        wavs = []
        tids = []
        lyrics_list = []
        desc_list = []

        # 1) —á–∏—Ç–∞–µ–º –∞—É–¥–∏–æ –∏ —Å–æ–±–∏—Ä–∞–µ–º —Ç–µ–∫—Å—Ç—ã —Ç–æ–ª—å–∫–æ –¥–ª—è —É—Å–ø–µ—à–Ω—ã—Ö —Ç—Ä–µ–∫–æ–≤
        for item in batch_items:
            tid = item['track_id']
            path = os.path.join(AUDIO_ROOT, f"{tid}.mp3")
            try:
                wav, sr = torchaudio.load(path)
            except Exception as e:
                print(f"  Skipping {tid}: failed to load audio ({e})")
                continue

            if sr != clap_sr:
                wav = torchaudio.functional.resample(wav, sr, clap_sr)
            if wav.size(0) > 1:
                wav = wav.mean(0, keepdim=True)

            wavs.append(wav.squeeze(0).cpu().numpy())
            tids.append(tid)

            lyr = item.get('lyrics', "")
            desc = item.get('description', "")
            if not isinstance(lyr, str):
                lyr = ""
            if not isinstance(desc, str):
                desc = ""
            lyrics_list.append(lyr)
            desc_list.append(desc)

        if not wavs:
            return

        # 2) CLAP audio
        audio_inputs = clap_processor.feature_extractor(
            raw_speech=wavs,
            sampling_rate=clap_sr,
            return_tensors="pt",
            padding=True,
        )
        audio_inputs = {k: v.to(device) for k, v in audio_inputs.items()}
        audio_emb = clap_model.get_audio_features(
            input_features=audio_inputs['input_features']
        )

        # 3) —Ç–µ–∫—Å—Ç—ã –±–∞—Ç—á–µ–º (BiEncoder + CLAP text)
        try:
            lyr_enc = text_tok(
                lyrics_list, truncation=True, padding=True,
                max_length=1024, return_tensors='pt'
            ).to(device)
            desc_enc = text_tok(
                desc_list, truncation=True, padding=True,
                max_length=512, return_tensors='pt'
            ).to(device)

            lyr_emb = biencoder.encode_lyrics(
                lyr_enc['input_ids'], lyr_enc['attention_mask']
            )
            desc_txt_emb = biencoder.encode_description(
                desc_enc['input_ids'], desc_enc['attention_mask']
            )

            clap_txt = clap_processor(
                text=desc_list, return_tensors='pt', padding=True, truncation=True
            )
            clap_txt = {k: v.to(device) for k, v in clap_txt.items()}
            desc_audio_emb = clap_model.get_text_features(
                input_ids=clap_txt['input_ids'],
                attention_mask=clap_txt['attention_mask'],
            )
        except RuntimeError as e:
            if "CUDA out of memory" in str(e) or "CUDACachingAllocator" in str(e):
                print("  Skipping batch due to OOM:", e)
                torch.cuda.empty_cache()
                return
            else:
                raise

        B = len(tids)
        assert audio_emb.size(0) == B
        assert lyr_emb.size(0) == B
        assert desc_txt_emb.size(0) == B
        assert desc_audio_emb.size(0) == B

        for i, tid in enumerate(tids):
            all_feats[tid] = {
                'audio_emb':      audio_emb[i].cpu(),
                'lyrics_emb':     lyr_emb[i].cpu(),
                'desc_audio_emb': desc_audio_emb[i].cpu(),
                'desc_text_emb':  desc_txt_emb[i].cpu(),
            }

    batch = []
    for item in tqdm(data, desc=f"Precomputing {split_name}"):
        batch.append({
            'track_id':    item['track_id'],
            'lyrics':      item.get('lyrics', ""),
            'description': item.get('description', ""),
        })
        if len(batch) == BATCH_SIZE:
            process_batch(batch)
            batch = []

    if batch:
        process_batch(batch)

    print(f"Computed features for {len(all_feats)} tracks (from {len(data)})")
    torch.save(all_feats, out_feats_path)
    print(f"Saved features to {out_feats_path}")

# –∑–∞–ø—É—Å–∫–∞–µ–º –ø–æ –æ—á–µ—Ä–µ–¥–∏
for split in ["train", "val", "test"]:
    process_split(split)

Total items in /home/jovyan/persistent_volume/final_dataset/test.json: 4096


Precomputing features:   8%|‚ñä         | 336/4096 [06:48<1:15:42,  1.21s/it]

  Skipping sKKgabkwl44SoTR7: failed to load audio (Failed to open the input "/home/jovyan/persistent_volume/final_dataset/audio/sKKgabkwl44SoTR7.mp3" (Invalid argument).)


Precomputing features:  60%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ    | 2448/4096 [49:31<34:43,  1.26s/it]  

  Skipping z3m1XXTGReQx4Jdd: failed to load audio (Failed to open the input "/home/jovyan/persistent_volume/final_dataset/audio/z3m1XXTGReQx4Jdd.mp3" (Invalid argument).)


Precomputing features:  63%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé   | 2592/4096 [52:24<31:43,  1.27s/it]

  Skipping wRRgs5XpbKKIe2KZ: failed to load audio (Failed to open the input "/home/jovyan/persistent_volume/final_dataset/audio/wRRgs5XpbKKIe2KZ.mp3" (Invalid argument).)


Precomputing features: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4096/4096 [1:24:08<00:00,  1.23s/it]

Computed features for 1363 tracks (from 4096)





In [64]:
# ============================================================
# 2. BiEncoder
# ============================================================

class BiEncoder(nn.Module):
    def __init__(self, backbone, projection_dim, p_drop):
        super().__init__()
        self.backbone = backbone
        emb_dim = self.backbone.config.hidden_size

        def head():
            return nn.Sequential(
                nn.Linear(emb_dim, projection_dim),
                nn.ReLU(inplace=True),
                nn.Dropout(p_drop),
                nn.Linear(projection_dim, projection_dim),
                nn.Dropout(p_drop),
            )

        self.desc_head = head()
        self.lyr_head = head()

    def _mean_pool(self, outputs, attention_mask):
        # –∫–∞–∫ —É —Ç–µ–±—è: –±–µ–∑ autocast, –≤ float32 [file:22]
        hs = outputs.last_hidden_state.float()
        mask = attention_mask.unsqueeze(-1).float()
        denom = mask.sum(1).clamp_min(1e-6)
        pooled = (hs * mask).sum(1) / denom
        return pooled

    def encode_description(self, ids, mask):
        out = self.backbone(input_ids=ids, attention_mask=mask)
        proj = self.desc_head(self._mean_pool(out, mask))
        return F.normalize(proj, p=2, dim=1)

    def encode_lyrics(self, ids, mask):
        out = self.backbone(input_ids=ids, attention_mask=mask)
        proj = self.lyr_head(self._mean_pool(out, mask))
        return F.normalize(proj, p=2, dim=1)


In [65]:
print("Loading BiEncoder backbone/tokenizer...")
text_tokenizer = AutoTokenizer.from_pretrained(cfg.text_model_name)
text_backbone = AutoModel.from_pretrained(cfg.text_model_name, trust_remote_code=True)
text_backbone.gradient_checkpointing_enable()
biencoder = BiEncoder(text_backbone, cfg.projection_dim, cfg.dropout).to(device)

# –≥—Ä—É–∑–∏–º —Ç–æ–ª—å–∫–æ –≥–æ–ª–æ–≤—ã, backbone –∑–∞–º–æ—Ä–æ–∑–∏–º
state = torch.load(cfg.biencoder_ckpt, map_location="cpu")
biencoder.load_state_dict(state)
for p in biencoder.backbone.parameters():
    p.requires_grad = False
biencoder.eval()
print("‚úì BiEncoder loaded from", cfg.biencoder_ckpt)

Loading BiEncoder backbone/tokenizer...
‚úì BiEncoder loaded from 25_ep_tag_hard_negs_bge_8192/bi_encoder_best.pth


In [66]:
# ============================================================
# 3. CLAP
# ============================================================

print("Loading CLAP model & processor...")
clap_processor = AutoProcessor.from_pretrained(cfg.clap_ckpt_dir)
clap_model = ClapModel.from_pretrained(cfg.clap_ckpt_dir).to(device)
clap_model.eval()
for p in clap_model.parameters():
    p.requires_grad = False
print("‚úì CLAP loaded from", cfg.clap_ckpt_dir)

Loading CLAP model & processor...
‚úì CLAP loaded from clap


In [67]:
# ============================================================
# 4. Fusion encoder (concat + MLP)
# ============================================================

D_AUDIO = clap_model.config.projection_dim  # –æ–±—ã—á–Ω–æ 512 [file:139]
D_TEXT = cfg.projection_dim                # 2048 –∏–∑ BiEncoder

class FusionEncoder(nn.Module):
    def __init__(self, dim_audio, dim_text, fused_dim=512, hidden=1024, p_drop=0.3):
        super().__init__()
        in_dim = dim_audio + dim_text
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(p_drop),
            nn.Linear(hidden, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(p_drop),
            nn.Linear(hidden, fused_dim),
        )
        self.norm = nn.LayerNorm(fused_dim)

    def forward(self, audio_emb, text_emb):
        x = torch.cat([audio_emb, text_emb], dim=-1)
        x = self.mlp(x)
        x = self.norm(x)
        return F.normalize(x, p=2, dim=-1)

fusion_encoder = FusionEncoder(D_AUDIO, D_TEXT,
                               fused_dim=256,
                               hidden=1024,
                               p_drop=0.2).to(device)


In [68]:
# ============================================================
# 5. Dataset: audio + lyrics + description
# ============================================================

import os, json, torch
from torch.utils.data import Dataset

class FusionDataset(Dataset):
    def __init__(self, json_path, feats_path):
        with open(json_path, "r", encoding="utf-8") as f:
            items = json.load(f)
        feats = torch.load(feats_path)

        self.pairs = []
        for it in items:
            tid = it["track_id"]
            if tid not in feats:
                continue  # —Ç—Ä–µ–∫–∏ –±–µ–∑ —Ñ–∏—á –≤—ã–∫–∏–¥—ã–≤–∞–µ–º
            self.pairs.append({
                "track_id": tid,
                "description": it["description"],
            })
        self.feats = feats

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        pair = self.pairs[idx]
        tid = pair["track_id"]
        f = self.feats[tid]
        return {
            "track_id":      tid,
            "audio_emb":      f["audio_emb"],      # —Ç—Ä–µ–∫–æ–≤—ã–µ —Ñ–∏—á–∏
            "lyrics_emb":     f["lyrics_emb"],
            "desc_audio_emb": f["desc_audio_emb"], # CLAP‚Äë—Ç–µ–∫—Å—Ç –æ–ø–∏—Å–∞–Ω–∏—è
            "desc_text_emb":  f["desc_text_emb"],  # BiEncoder‚Äë—Ç–µ–∫—Å—Ç –æ–ø–∏—Å–∞–Ω–∏—è
        }

def collate_fn(batch):
    return {
        "track_id":      [b["track_id"] for b in batch],
        "audio_emb":      torch.stack([b["audio_emb"] for b in batch]),
        "lyrics_emb":     torch.stack([b["lyrics_emb"] for b in batch]),
        "desc_audio_emb": torch.stack([b["desc_audio_emb"] for b in batch]),
        "desc_text_emb":  torch.stack([b["desc_text_emb"] for b in batch]),
    }

In [69]:
train_ds = FusionDataset(cfg.train_json, feats_path=OUT_TRAIN_FEATS)
val_ds   = FusionDataset(cfg.val_json, feats_path=OUT_VAL_FEATS)
test_ds  = FusionDataset(cfg.test_json, feats_path=OUT_TEST_FEATS)

train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True,
                          num_workers=cfg.num_workers, collate_fn=collate_fn)
val_loader   = DataLoader(val_ds,   batch_size=cfg.batch_size, shuffle=False,
                          num_workers=cfg.num_workers, collate_fn=collate_fn)
test_loader  = DataLoader(test_ds,  batch_size=cfg.batch_size, shuffle=False,
                          num_workers=cfg.num_workers, collate_fn=collate_fn)

print(f"Train: {len(train_ds)}, Val: {len(val_ds)}, Test: {len(test_ds)}")

Train: 32749, Val: 4093, Test: 4093


In [70]:
# ============================================================
# 6. Loss –∏ –æ–ø—Ç–∏–º–∏–∑–∞—Ç–æ—Ä
# ============================================================

def clip_loss(q, t, temperature):
    logits = (q @ t.t()) / temperature
    labels = torch.arange(logits.size(0), device=logits.device)
    loss_i2t = F.cross_entropy(logits, labels)
    loss_t2i = F.cross_entropy(logits.t(), labels)
    return (loss_i2t + loss_t2i) / 2


optimizer = AdamW(fusion_encoder.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
scheduler = OneCycleLR(optimizer, max_lr=cfg.lr,
                       epochs=cfg.epochs, steps_per_epoch=len(train_loader))
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))


In [71]:
train_losses, val_losses, lrs = [], [], []
best_val_loss = float("inf")
best_path = os.path.join(cfg.out_dir, "fusion_best.pth")

In [72]:
# ============================================================
# 7. –í—Å–ø–æ–º–æ–≥–∞—Ç–µ–ª—å–Ω—ã–µ —Ñ—É–Ω–∫—Ü–∏–∏: encode –∏ metrics
# ============================================================

def encode_batch(batch):
    audio = batch["audio"].to(device)
    lyr_ids = batch["lyrics_input_ids"].to(device)
    lyr_mask = batch["lyrics_attention_mask"].to(device)
    desc_ids = batch["desc_input_ids"].to(device)
    desc_mask = batch["desc_attention_mask"].to(device)
    clap_ids = batch["clap_input_ids"].to(device)
    clap_mask = batch["clap_attention_mask"]
    if clap_mask is not None:
        clap_mask = clap_mask.to(device)

    # –∑–∞–º–æ—Ä–æ–∂–µ–Ω–Ω—ã–µ —ç–Ω–∫–æ–¥–µ—Ä—ã —Å—á–∏—Ç–∞–µ–º –ø–æ–¥ no_grad
    with torch.no_grad():
        audio_inputs = clap_processor.feature_extractor(
            raw_speech=audio.squeeze(1).cpu().numpy(),
            sampling_rate=clap_processor.feature_extractor.sampling_rate,
            return_tensors="pt",
            padding=True,
        )
        audio_feat = clap_model.get_audio_features(
            input_features=audio_inputs["input_features"].to(device)
        )                           # (B, D_A)
        audio_feat = F.normalize(audio_feat, p=2, dim=-1)

        lyr_emb = biencoder.encode_lyrics(lyr_ids, lyr_mask)          # (B, D_T)
        desc_emb = biencoder.encode_description(desc_ids, desc_mask)  # (B, D_T)

        clap_text_feat = clap_model.get_text_features(
            input_ids=clap_ids,
            attention_mask=clap_mask,
        )                           # (B, D_A)
        clap_text_feat = F.normalize(clap_text_feat, p=2, dim=-1)

    # fusion_encoder –±–µ–∑ no_grad
    track_fused = fusion_encoder(audio_feat, lyr_emb)        # (B, D_fused)
    query_fused = fusion_encoder(clap_text_feat, desc_emb)   # (B, D_fused)

    return query_fused, track_fused


@torch.no_grad()
def evaluate_val():
    fusion_encoder.eval()
    total_loss = 0.0

    for batch in tqdm(val_loader, desc="Validating"):
        q = fusion_encoder(batch['desc_audio_emb'].to(device),
                           batch['desc_text_emb'].to(device))
        t = fusion_encoder(batch['audio_emb'].to(device),
                           batch['lyrics_emb'].to(device))
        loss = clip_loss(q, t, cfg.temperature)
        total_loss += loss.item()

    return total_loss / len(val_loader)


@torch.no_grad()
def compute_embeddings(loader, model):
    model.eval()
    all_q, all_t = [], []

    for batch in tqdm(loader, desc="Computing embeddings"):
        q = model(batch['desc_audio_emb'].to(device),
                           batch['desc_text_emb'].to(device))
        t = model(batch['audio_emb'].to(device),
                           batch['lyrics_emb'].to(device))
        all_q.append(q.cpu())
        all_t.append(t.cpu())

    q_emb = torch.cat(all_q, dim=0)  # (N, D)
    t_emb = torch.cat(all_t, dim=0)  # (N, D)
    return q_emb, t_emb


@torch.no_grad()
def compute_retrieval_metrics(q_emb, t_emb, k_values=(1, 5, 10, 20)):
    sim = torch.matmul(q_emb, t_emb.t())  # (N, N)
    num_queries = sim.size(0)

    recall_at_k = {k: 0 for k in k_values}
    precision_at_k = {k: 0 for k in k_values}
    mrr = 0.0

    for i in range(num_queries):
        ranking = torch.argsort(sim[i], descending=True)
        correct_idx = i
        position = (ranking == correct_idx).nonzero(as_tuple=True)[0].item()
        rank = position + 1
        mrr += 1.0 / rank

        for k in k_values:
            if rank <= k:
                recall_at_k[k] += 1
                precision_at_k[k] += 1.0 / k

    mrr /= num_queries
    for k in k_values:
        recall_at_k[k] /= num_queries
        precision_at_k[k] /= num_queries

    return mrr, recall_at_k, precision_at_k

In [73]:
from IPython.display import clear_output

# ============================================================
# 8. –¢—Ä–µ–Ω–∏—Ä–æ–≤–∫–∞ —Å –≤–∞–ª–∏–¥–∞—Ü–∏–µ–π –∏ –≥—Ä–∞—Ñ–∏–∫–æ–º
# ============================================================

print("\n" + "="*70)
print("START FUSION TRAINING")
print("="*70)

train_losses = []
val_losses = []
lrs = []

step_losses = []        # –ª–æ—Å—Å –Ω–∞ –∫–∞–∂–¥–æ–º optimizer step
global_step = 0

for epoch in range(cfg.epochs):
    print(f"\nEpoch {epoch+1}/{cfg.epochs}")
    fusion_encoder.train()
    total_loss = 0.0

    for batch_idx, batch in enumerate(tqdm(train_loader, desc="Training")):
        optimizer.zero_grad(set_to_none=True)

        with torch.cuda.amp.autocast(enabled=(device == "cuda"),
                                     dtype=torch.bfloat16):
            q = fusion_encoder(batch['desc_audio_emb'].to(device),
                               batch['desc_text_emb'].to(device))
            t = fusion_encoder(batch['audio_emb'].to(device),
                               batch['lyrics_emb'].to(device))
            loss = clip_loss(q, t, cfg.temperature)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        total_loss += loss.item()
        step_losses.append(loss.item())
        global_step += 1

        # # –∫–∞–∂–¥—ã–µ 50 —à–∞–≥–æ–≤ ‚Äî –æ–±–Ω–æ–≤–ª—è–µ–º –≥—Ä–∞—Ñ–∏–∫ train loss
        # if global_step % 50 == 0:
        #     clear_output(wait=True)
        #     plt.figure(figsize=(10, 4))
        #     plt.plot(step_losses, label="Train loss per step")
        #     plt.xlabel("Optimizer step")
        #     plt.ylabel("Loss")
        #     plt.title(f"Train loss (epoch {epoch+1}, step {global_step})")
        #     plt.grid(alpha=0.3)
        #     plt.legend()
        #     plt.show()

    train_loss = total_loss / len(train_loader)
    train_losses.append(train_loss)
    lrs.append(optimizer.param_groups[0]["lr"])
    print(f"Train loss: {train_loss:.4f}")

    # Val
    val_loss = evaluate_val()
    val_losses.append(val_loss)
    print(f"Val loss: {val_loss:.4f}")

    # Save best
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(fusion_encoder.state_dict(), best_path)
        print("üèÜ New best fusion model saved")

# –ì—Ä–∞—Ñ–∏–∫–∏ –ø–æ —ç–ø–æ—Ö–∞–º (PNG)
plt.figure(figsize=(12,5))
plt.subplot(1,2,1)
plt.plot(train_losses, label="Train")
plt.plot(val_losses, label="Val")
plt.xlabel("Epoch"); plt.ylabel("Loss")
plt.title("Fusion Train/Val Loss")
plt.grid(alpha=0.3); plt.legend()

plt.subplot(1,2,2)
plt.plot(lrs, label="LR")
plt.yscale("log")
plt.xlabel("Epoch"); plt.ylabel("LR")
plt.title("Learning rate")
plt.grid(alpha=0.3); plt.legend()

plt.tight_layout()
plt.savefig(os.path.join(cfg.out_dir, "fusion_losses.png"))
print("‚úì Saved loss curves to fusion_losses.png")


START FUSION TRAINING

Epoch 1/5


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4094/4094 [00:17<00:00, 239.51it/s]


Train loss: 1.3915


Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 512/512 [00:00<00:00, 1339.61it/s]


Val loss: 0.6713
üèÜ New best fusion model saved

Epoch 2/5


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4094/4094 [00:12<00:00, 319.68it/s]


Train loss: 0.2749


Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 512/512 [00:00<00:00, 1351.94it/s]


Val loss: 0.5664
üèÜ New best fusion model saved

Epoch 3/5


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4094/4094 [00:12<00:00, 322.27it/s]


Train loss: 0.1322


Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 512/512 [00:00<00:00, 1358.62it/s]


Val loss: 0.5436
üèÜ New best fusion model saved

Epoch 4/5


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4094/4094 [00:12<00:00, 322.57it/s]


Train loss: 0.0911


Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 512/512 [00:00<00:00, 1362.49it/s]


Val loss: 0.5330
üèÜ New best fusion model saved

Epoch 5/5


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4094/4094 [00:12<00:00, 322.54it/s]


Train loss: 0.0781


Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 512/512 [00:00<00:00, 1346.46it/s]


Val loss: 0.5359
‚úì Saved loss curves to fusion_losses.png


In [74]:
# ============================================================
# 9. –û—Ü–µ–Ω–∫–∞ –Ω–∞ test: MRR, Recall@K, Precision@K
# ============================================================

print("\n" + "="*70)
print("EVALUATING ON TEST SET")
print("="*70)

fusion_encoder.load_state_dict(torch.load(best_path, map_location=device))
fusion_encoder.to(device).eval()
print("‚úì Loaded best fusion model")

q_emb, t_emb = compute_embeddings(test_loader, fusion_encoder)
print("Embeddings:", q_emb.shape, t_emb.shape)

mrr, recall_at_k, precision_at_k = compute_retrieval_metrics(q_emb, t_emb)

print("\n" + "="*70)
print("FUSION RETRIEVAL METRICS (Test)")
print("="*70)
print(f"MRR: {mrr:.4f}")
for k in [1, 5, 10, 20]:
    print(f"Recall@{k}:    {recall_at_k[k]:.4f}")
    print(f"Precision@{k}: {precision_at_k[k]:.4f}")


EVALUATING ON TEST SET
‚úì Loaded best fusion model


Computing embeddings: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 512/512 [00:00<00:00, 1530.16it/s]


Embeddings: torch.Size([4093, 256]) torch.Size([4093, 256])

FUSION RETRIEVAL METRICS (Test)
MRR: 0.1651
Recall@1:    0.0760
Precision@1: 0.0760
Recall@5:    0.2516
Precision@5: 0.0503
Recall@10:    0.3609
Precision@10: 0.0361
Recall@20:    0.4698
Precision@20: 0.0235


In [75]:
import os, torch

ROOT = "/home/jovyan/persistent_volume/final_dataset"

# –≥—Ä—É–∑–∏–º —Ñ–∏—á–∏ —Ç—Ä–µ–∫–æ–≤ –∏–∑ —Ç—Ä—ë—Ö —Å–ø–ª–∏—Ç–æ–≤
train_feats = torch.load(os.path.join(ROOT, "train_features.pt"))
val_feats   = torch.load(os.path.join(ROOT, "val_features.pt"))
test_feats  = torch.load(os.path.join(ROOT, "test_features.pt"))

# –æ–±—ä–µ–¥–∏–Ω—è–µ–º –≤ –æ–¥–∏–Ω —Å–ª–æ–≤–∞—Ä—å track_id -> —Ñ–∏—á–∏
all_feats = {}
all_feats.update(train_feats)
all_feats.update(val_feats)
all_feats.update(test_feats)

all_track_ids = sorted(all_feats.keys())  # —Ñ–∏–∫—Å–∏—Ä—É–µ–º –ø–æ—Ä—è–¥–æ–∫

# —Å–æ–±–∏—Ä–∞–µ–º —Ç–µ–Ω–∑–æ—Ä—ã —ç–º–±–µ–¥–¥–∏–Ω–≥–æ–≤ —Ç—Ä–µ–∫–æ–≤
audio_embs  = torch.stack([all_feats[tid]["audio_emb"]  for tid in all_track_ids])
lyrics_embs = torch.stack([all_feats[tid]["lyrics_emb"] for tid in all_track_ids])

# –ø—Ä–æ–≥–æ–Ω—è–µ–º —á–µ—Ä–µ–∑ fusion_encoder, —á—Ç–æ–±—ã –ø–æ–ª—É—á–∏—Ç—å fused —Ç—Ä–µ–∫–æ–≤—ã–µ —ç–º–±–µ–¥–¥–∏–Ω–≥–∏
fusion_encoder.to(device).eval()
with torch.no_grad():
    t_emb_all = fusion_encoder(
        audio_embs.to(device),
        lyrics_embs.to(device),
    ).cpu()   # (N_tracks, D)

print(t_emb_all.shape, len(all_track_ids))


torch.Size([2835, 256]) 2835


In [76]:
from tqdm import tqdm

fusion_encoder.eval()
all_q_test = []
test_track_ids = []   # ¬´–ø—Ä–∞–≤–∏–ª—å–Ω—ã–π¬ª —Ç—Ä–µ–∫ –¥–ª—è –∫–∞–∂–¥–æ–≥–æ –æ–ø–∏—Å–∞–Ω–∏—è

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Fuse test queries"):
        q = fusion_encoder(
            batch["desc_audio_emb"].to(device),
            batch["desc_text_emb"].to(device),
        )
        all_q_test.append(q.cpu())

        # –±–µ—Ä–µ–º track_id –∏–∑ –∏—Å—Ö–æ–¥–Ω–æ–≥–æ —Ç–µ—Å—Ç–æ–≤–æ–≥–æ json, –µ—Å–ª–∏ –≤ –¥–∞—Ç–∞—Å–µ—Ç–µ –æ–Ω –µ—Å—Ç—å
        test_track_ids.extend(batch["track_id"])  # –Ω—É–∂–Ω–æ –¥–æ–±–∞–≤–∏—Ç—å –≤ __getitem__

q_emb_test = torch.cat(all_q_test, dim=0)  # (N_queries, D)
print(q_emb_test.shape, len(test_track_ids))


Fuse test queries: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 512/512 [00:00<00:00, 2579.36it/s]


torch.Size([4093, 256]) 4093


In [77]:
import faiss
import numpy as np

t_all_np = t_emb_all.numpy().astype("float32")
q_test_np = q_emb_test.numpy().astype("float32")

D = t_all_np.shape[1]
index = faiss.IndexFlatL2(D)
index.add(t_all_np)

K = 50
distances, indices = index.search(q_test_np, K)

track_id_array = np.array(all_track_ids)

def compute_faiss_metrics(indices, track_id_array, test_track_ids, ks=(1, 5, 10)):
    N, K = indices.shape
    ks = sorted(ks)
    max_k = ks[-1]

    mrr = 0.0
    recall_at = {k: 0 for k in ks}
    precision_at = {k: 0.0 for k in ks}

    for i in range(N):
        true_tid = test_track_ids[i]
        retrieved_tids = track_id_array[indices[i, :max_k]]

        hits = np.where(retrieved_tids == true_tid)[0]
        if len(hits) > 0:
            rank = hits[0] + 1
            mrr += 1.0 / rank
            for k in ks:
                if rank <= k:
                    recall_at[k] += 1

        for k in ks:
            topk = retrieved_tids[:k]
            correct = np.sum(topk == true_tid)
            precision_at[k] += correct / k

    mrr /= N
    for k in ks:
        recall_at[k] /= N
        precision_at[k] /= N

    return mrr, recall_at, precision_at

mrr, recall_at_k, precision_at_k = compute_faiss_metrics(
    indices, track_id_array, test_track_ids, ks=(1, 5, 10, 20)
)

print("MRR:", mrr)
for k, v in recall_at_k.items():
    print(f"Recall@{k}:", v)
for k, v in precision_at_k.items():
    print(f"Precision@{k}:", v)


MRR: 0.2823589794327138
Recall@1: 0.2030295626679697
Recall@5: 0.36501343757634985
Recall@10: 0.4490593696555094
Recall@20: 0.535304177864647
Precision@1: 0.2030295626679697
Precision@5: 0.07300268751526792
Precision@10: 0.04490593696554944
Precision@20: 0.026765208893231352
