In [1]:
import os
import json
from dataclasses import dataclass
from datetime import datetime

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from tqdm import tqdm
import torchaudio

from transformers import AutoTokenizer, AutoModel, AutoProcessor, ClapModel

In [2]:
# ============================================================
# 1. Config
# ============================================================

@dataclass
class Config:
    """All tunable settings of the fusion experiment: data paths, encoders, training."""

    # Split files (JSON lists of tracks); relative paths resolve against the kernel's cwd.
    train_json: str = "final_dataset/train.json"
    val_json: str   = "final_dataset/val.json"
    test_json: str  = "final_dataset/test.json"

    # BiEncoder: text backbone name, projection-head width / dropout, trained weights
    # (loaded with load_state_dict).
    text_model_name: str = "BAAI/bge-m3"
    projection_dim: int = 2048
    dropout: float = 0.3
    biencoder_ckpt: str = "25_ep_tag_hard_negs_bge_8192/bi_encoder_best.pth"

    clap_ckpt_dir: str = "clap"  # directory written by save_pretrained() in the Jamendo script

    # NOTE(review): neither of these two fields is read in this notebook — audio is
    # resampled to CLAP's feature_extractor.sampling_rate and is never cropped here.
    audio_sr: int = 48000
    max_audio_seconds: int = 30

    # Tokenizer truncation lengths (max_length, in tokens) for descriptions and lyrics.
    max_desc_len: int = 4096
    max_lyrics_len: int = 4096

    # Fusion-encoder training. `lr` is the peak LR passed to OneCycleLR, `fused_dim`
    # is the FusionEncoder output width, `temperature` is the CLIP-loss temperature.
    batch_size: int = 8
    num_workers: int = 0
    epochs: int = 5
    lr: float = 1e-4
    weight_decay: float = 1e-2
    fused_dim: int = 512
    temperature: float = 0.07

    # Output directory for the best checkpoint / loss plot, and the JSON file that
    # validation losses are appended to.
    out_dir: str = "fusion_ckpts"
    val_log_path: str = "fusion_val_losses.json"


cfg = Config()
os.makedirs(cfg.out_dir, exist_ok=True)

# Directory holding the <track_id>.mp3 files. Defaults to the persistent-volume
# layout; on another machine (Colab, cluster, ...) set the AUDIO_ROOT environment
# variable instead of editing the notebook.
AUDIO_ROOT = os.environ.get(
    "AUDIO_ROOT",
    os.path.expanduser('~/persistent_volume/final_dataset/audio/audio'),
)

# Seed the RNGs behind weight initialisation and DataLoader shuffling so that
# training runs are repeatable.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Device: cuda


In [None]:
# ============================================================
# Precompute frozen-encoder features for every split
# ============================================================
# For each track this stores the CLAP audio embedding, the CLAP text embedding
# of its description and the BiEncoder lyrics / description embeddings, so the
# fusion model can be trained without re-running the frozen encoders.
# The BiEncoder class is defined in the "2. BiEncoder" cell, which must run first.
if "BiEncoder" not in globals():
    raise RuntimeError("Run the '2. BiEncoder' cell before precomputing features.")

# --- paths / settings ---
PROJECT_DIR = os.path.expanduser('~/persistent_volume')
AUDIO_ROOT = os.environ.get(
    "AUDIO_ROOT", os.path.join(PROJECT_DIR, 'final_dataset', 'audio', 'audio')
)
SPLITS = ('train', 'val', 'test')   # the fusion datasets need features for all three
BATCH_SIZE = 32
SAVE_EVERY_BATCHES = 100            # checkpoint partial results; a full pass takes hours

# Keep `device` a torch.device (not a str): later cells rely on `device.type`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- frozen encoders ---
text_tok = AutoTokenizer.from_pretrained(cfg.text_model_name)
text_backbone = AutoModel.from_pretrained(cfg.text_model_name, trust_remote_code=True)
biencoder = BiEncoder(text_backbone, cfg.projection_dim, cfg.dropout).to(device)
biencoder.load_state_dict(torch.load(cfg.biencoder_ckpt, map_location=device))
biencoder.eval()  # disables the heads' Dropout so embeddings are deterministic

clap_processor = AutoProcessor.from_pretrained(cfg.clap_ckpt_dir)
clap_model = ClapModel.from_pretrained(cfg.clap_ckpt_dir).to(device).eval()
clap_sr = clap_processor.feature_extractor.sampling_rate


def load_mono_audio(track_id):
    """Decode <AUDIO_ROOT>/<track_id>.mp3 into a mono 1-D numpy array at CLAP's sampling rate."""
    wav, sr = torchaudio.load(os.path.join(AUDIO_ROOT, f"{track_id}.mp3"))
    if wav.size(0) > 1:
        # Downmix before resampling: resampling is linear, so this matches
        # resample-then-average while only resampling a single channel.
        wav = wav.mean(0, keepdim=True)
    if sr != clap_sr:
        wav = torchaudio.functional.resample(wav, sr, clap_sr)
    return wav.squeeze(0).cpu().numpy()


@torch.no_grad()
def process_batch(batch_items):
    """Encode one batch of {track_id, lyrics, description} dicts.

    Returns:
        (feats, failed): `feats` maps track_id -> {'audio_emb', 'lyrics_emb',
        'desc_audio_emb', 'desc_text_emb'} CPU tensors; `failed` lists
        (track_id, error) for tracks whose audio could not be loaded. Such tracks
        are skipped instead of aborting a multi-hour run.
    """
    wavs, items, failed = [], [], []
    for item in batch_items:
        try:
            wavs.append(load_mono_audio(item['track_id']))
        except Exception as e:  # missing / undecodable mp3
            failed.append((item['track_id'], repr(e)))
            continue
        items.append(item)
    if not items:
        return {}, failed

    # 1) CLAP audio embeddings (batched)
    audio_inputs = clap_processor.feature_extractor(
        raw_speech=wavs,
        sampling_rate=clap_sr,
        return_tensors="pt",
        padding=True,
    )
    is_longer = audio_inputs.get('is_longer')
    audio_emb = clap_model.get_audio_features(
        input_features=audio_inputs['input_features'].to(device),
        # Forward the extractor's `is_longer` flags too (fusion-enabled CLAP uses them).
        is_longer=None if is_longer is None else is_longer.to(device),
    )                                                   # (B, D_A)

    # 2) BiEncoder text embeddings — texts taken from `items` so rows stay
    #    aligned with the audio that actually loaded.
    lyrics_list = [it['lyrics'] for it in items]
    desc_list = [it['description'] for it in items]
    lyr_enc = text_tok(
        lyrics_list, truncation=True, padding=True,
        max_length=cfg.max_lyrics_len, return_tensors='pt'
    ).to(device)
    desc_enc = text_tok(
        desc_list, truncation=True, padding=True,
        max_length=cfg.max_desc_len, return_tensors='pt'
    ).to(device)
    lyr_emb = biencoder.encode_lyrics(lyr_enc['input_ids'], lyr_enc['attention_mask'])
    desc_txt_emb = biencoder.encode_description(desc_enc['input_ids'], desc_enc['attention_mask'])

    # 3) CLAP text embeddings of the descriptions
    clap_txt = clap_processor(text=desc_list, return_tensors='pt', padding=True, truncation=True)
    desc_audio_emb = clap_model.get_text_features(
        input_ids=clap_txt['input_ids'].to(device),
        attention_mask=clap_txt['attention_mask'].to(device),
    )                                                   # (B, D_A)

    # 4) split back per track; clone() gives every row its own compact storage
    audio_emb, lyr_emb = audio_emb.cpu(), lyr_emb.cpu()
    desc_audio_emb, desc_txt_emb = desc_audio_emb.cpu(), desc_txt_emb.cpu()
    feats = {
        it['track_id']: {
            'audio_emb':      audio_emb[i].clone(),
            'lyrics_emb':     lyr_emb[i].clone(),
            'desc_audio_emb': desc_audio_emb[i].clone(),
            'desc_text_emb':  desc_txt_emb[i].clone(),
        }
        for i, it in enumerate(items)
    }
    return feats, failed


def _atomic_save(obj, path):
    """torch.save via a temp file + os.replace, so an interrupted save never corrupts `path`."""
    tmp_path = path + ".tmp"
    torch.save(obj, tmp_path)
    os.replace(tmp_path, path)


def precompute_split(split):
    """Compute (or resume) features for final_dataset/<split>.json; returns the failures."""
    data_json = os.path.join(PROJECT_DIR, 'final_dataset', f'{split}.json')
    out_feats = os.path.join(PROJECT_DIR, 'final_dataset', f'{split}_features.pt')

    with open(data_json, 'r', encoding='utf-8') as f:
        data = json.load(f)
    print(f"Total items in {data_json}: {len(data)}")

    # Resume: keep whatever a previous (possibly interrupted) run already saved.
    feats = torch.load(out_feats, map_location='cpu') if os.path.exists(out_feats) else {}
    todo = [it for it in data if it['track_id'] not in feats]
    failed = []

    with tqdm(total=len(todo), desc=f"Precomputing {split} features") as pbar:
        for n_batch, start in enumerate(range(0, len(todo), BATCH_SIZE), start=1):
            chunk = todo[start:start + BATCH_SIZE]
            batch_feats, batch_failed = process_batch(chunk)
            feats.update(batch_feats)
            failed.extend(batch_failed)
            pbar.update(len(chunk))
            if n_batch % SAVE_EVERY_BATCHES == 0:
                _atomic_save(feats, out_feats)

    if todo:
        _atomic_save(feats, out_feats)
    print(f"Features for {len(feats)} tracks in {out_feats} ({len(failed)} failed)")
    return failed


failed_tracks = {split: precompute_split(split) for split in SPLITS}
for split, failures in failed_tracks.items():
    if failures:
        print(f"{split}: first failures -> {failures[:5]}")

Total items in /home/jovyan/persistent_volume/final_dataset/train.json: 32760


Precomputing features:   4%|‚ñé         | 1184/32760 [39:34<17:30:11,  2.00s/it]

In [4]:
# ============================================================
# 2. BiEncoder
# ============================================================

class BiEncoder(nn.Module):
    """Shared text backbone with two projection heads (descriptions and lyrics).

    Each encoder mean-pools the backbone's last hidden state over non-padding
    tokens, passes it through its own MLP head and L2-normalises the result.
    """

    def __init__(self, backbone, projection_dim, p_drop):
        super().__init__()
        self.backbone = backbone
        hidden = backbone.config.hidden_size
        # The attribute names `desc_head` / `lyr_head` are the checkpoint's
        # state_dict prefixes and must not change.
        self.desc_head = self._make_head(hidden, projection_dim, p_drop)
        self.lyr_head = self._make_head(hidden, projection_dim, p_drop)

    @staticmethod
    def _make_head(in_dim, out_dim, p_drop):
        """Linear -> ReLU -> Dropout -> Linear -> Dropout projection head."""
        layers = [
            nn.Linear(in_dim, out_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(p_drop),
            nn.Linear(out_dim, out_dim),
            nn.Dropout(p_drop),
        ]
        return nn.Sequential(*layers)

    def _mean_pool(self, outputs, attention_mask):
        # Pool in float32 whatever dtype the backbone produced.
        hidden = outputs.last_hidden_state.float()
        weights = attention_mask.unsqueeze(-1).float()
        counts = weights.sum(1).clamp_min(1e-6)
        return (hidden * weights).sum(1) / counts

    def _encode(self, ids, mask, head):
        """Backbone -> masked mean pool -> `head` -> unit L2 norm along dim 1."""
        outputs = self.backbone(input_ids=ids, attention_mask=mask)
        return F.normalize(head(self._mean_pool(outputs, mask)), p=2, dim=1)

    def encode_description(self, ids, mask):
        return self._encode(ids, mask, self.desc_head)

    def encode_lyrics(self, ids, mask):
        return self._encode(ids, mask, self.lyr_head)


In [None]:
# Load the trained BiEncoder (tokenizer, text backbone, projection heads) and
# freeze its backbone.
print("Loading BiEncoder backbone/tokenizer...")
text_tokenizer = AutoTokenizer.from_pretrained(cfg.text_model_name)
text_backbone = AutoModel.from_pretrained(cfg.text_model_name, trust_remote_code=True)
text_backbone.gradient_checkpointing_enable()
biencoder = BiEncoder(text_backbone, cfg.projection_dim, cfg.dropout).to(device)

# The checkpoint holds the complete state (backbone AND both heads); it is loaded
# strictly, after which every backbone parameter stops receiving gradients.
state = torch.load(cfg.biencoder_ckpt, map_location="cpu")
biencoder.load_state_dict(state)
biencoder.backbone.requires_grad_(False)
biencoder.eval()
print("‚úì BiEncoder loaded from", cfg.biencoder_ckpt)

In [None]:
# ============================================================
# 3. CLAP
# ============================================================

print("Loading CLAP model & processor...")
clap_processor = AutoProcessor.from_pretrained(cfg.clap_ckpt_dir)
# CLAP is used as a frozen feature extractor: eval mode, no gradients anywhere.
clap_model = ClapModel.from_pretrained(cfg.clap_ckpt_dir).to(device).eval()
clap_model.requires_grad_(False)
print("‚úì CLAP loaded from", cfg.clap_ckpt_dir)

In [None]:
# ============================================================
# 4. Fusion encoder (concat + MLP)
# ============================================================

# Widths of the two halves that FusionEncoder concatenates:
#   D_AUDIO - CLAP projection size, shared by CLAP audio and CLAP text features
#             (presumably 512 for this checkpoint — TODO confirm)
#   D_TEXT  - BiEncoder projection size (cfg.projection_dim, 2048 by default)
D_AUDIO, D_TEXT = clap_model.config.projection_dim, cfg.projection_dim

class FusionEncoder(nn.Module):
    """Fuse an audio-space embedding and a text-space embedding into one unit vector.

    Args:
        dim_audio: width of the first input (CLAP embedding).
        dim_text: width of the second input (BiEncoder embedding).
        fused_dim: width of the hidden layer and of the output.
        normalize_inputs: L2-normalise each input before concatenation
            (default True). CLAP projection outputs are not necessarily unit-norm
            while BiEncoder outputs are, so without this the larger-scale half
            dominates the concatenation. For already-normalised inputs it is a
            no-op. Pass False to reproduce the previous behaviour.

    forward(audio_emb, text_emb) returns the MLP output L2-normalised along the
    last dimension.
    """

    def __init__(self, dim_audio, dim_text, fused_dim, normalize_inputs=True):
        super().__init__()
        self.normalize_inputs = normalize_inputs
        in_dim = dim_audio + dim_text
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, fused_dim),
            nn.ReLU(inplace=True),
            nn.Linear(fused_dim, fused_dim),
        )

    def forward(self, audio_emb, text_emb):
        if self.normalize_inputs:
            audio_emb = F.normalize(audio_emb, p=2, dim=-1)
            text_emb = F.normalize(text_emb, p=2, dim=-1)
        x = torch.cat([audio_emb, text_emb], dim=-1)
        x = self.mlp(x)
        return F.normalize(x, p=2, dim=-1)

fusion_encoder = FusionEncoder(
    dim_audio=D_AUDIO,
    dim_text=D_TEXT,
    fused_dim=cfg.fused_dim,
).to(device)

In [None]:

# ============================================================
# 5. Dataset: precomputed audio / lyrics / description features
# ============================================================

class FusionFeatDataset(Dataset):
    """Items of a split JSON paired with their precomputed embeddings.

    Args:
        json_path: JSON list of dicts, each with at least a 'track_id' key.
        feats_path: file written by torch.save holding
            {track_id: {'audio_emb', 'lyrics_emb', 'desc_audio_emb', 'desc_text_emb'}}.

    Items whose track_id has no entry in the features file (e.g. audio that
    could not be decoded during precomputation) are dropped with a printed
    notice. Without this they raise KeyError in the middle of an epoch.
    """

    FEATURE_KEYS = ('audio_emb', 'lyrics_emb', 'desc_audio_emb', 'desc_text_emb')

    def __init__(self, json_path, feats_path):
        with open(json_path, 'r', encoding='utf-8') as f:
            items = json.load(f)
        # map_location='cpu' so loading never requires a GPU.
        self.feats = torch.load(feats_path, map_location='cpu')  # dict tid -> tensors
        self.items = [it for it in items if it['track_id'] in self.feats]
        n_missing = len(items) - len(self.items)
        if n_missing:
            print(f"[FusionFeatDataset] {json_path}: skipped {n_missing} item(s) without features")

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        f = self.feats[self.items[idx]['track_id']]
        return {key: f[key] for key in self.FEATURE_KEYS}

def collate_fn_feats(batch):
    """Stack a list of per-item embedding dicts into one dict of batched tensors.

    Each of the four feature keys maps to torch.stack over the batch, so the
    item dimension becomes dim 0.
    """
    keys = ('audio_emb', 'lyrics_emb', 'desc_audio_emb', 'desc_text_emb')
    return {key: torch.stack([sample[key] for sample in batch]) for key in keys}

In [None]:
def feats_path_for(json_path):
    """Features file written by the precompute cell: 'x/train.json' -> 'x/train_features.pt'."""
    return os.path.splitext(json_path)[0] + "_features.pt"


def make_loader(ds, shuffle):
    """DataLoader over precomputed features, batched with collate_fn_feats."""
    return DataLoader(ds, batch_size=cfg.batch_size, shuffle=shuffle,
                      num_workers=cfg.num_workers, collate_fn=collate_fn_feats)


# NOTE(review): cfg paths are relative — this assumes the kernel's cwd is the
# project dir (~/persistent_volume) where the *_features.pt files live. TODO confirm.
train_ds = FusionFeatDataset(cfg.train_json, feats_path_for(cfg.train_json))
val_ds   = FusionFeatDataset(cfg.val_json,   feats_path_for(cfg.val_json))
test_ds  = FusionFeatDataset(cfg.test_json,  feats_path_for(cfg.test_json))

train_loader = make_loader(train_ds, shuffle=True)
val_loader   = make_loader(val_ds,   shuffle=False)
test_loader  = make_loader(test_ds,  shuffle=False)

print(f"Train: {len(train_ds)}, Val: {len(val_ds)}, Test: {len(test_ds)}")

In [None]:
# ============================================================
# 6. Loss and optimizer
# ============================================================

def clip_loss(q, t, temperature):
    """Symmetric InfoNCE (CLIP) loss: row i of `q` is the positive for row i of `t`.

    Args:
        q: query embeddings, one row per pair.
        t: target embeddings, row-aligned with `q`.
        temperature: divisor applied to the dot-product logits.

    Returns:
        Scalar tensor: mean of the q->t and t->q cross-entropy losses.
    """
    sim = q @ t.t() / temperature
    targets = torch.arange(sim.size(0), device=sim.device)
    forward_ce = F.cross_entropy(sim, targets)
    backward_ce = F.cross_entropy(sim.t(), targets)
    return 0.5 * (forward_ce + backward_ce)


# Only the fusion encoder is trained; the frozen encoders' outputs are precomputed.
optimizer = AdamW(fusion_encoder.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
# OneCycleLR is stepped once per batch, so it needs the batches per epoch.
scheduler = OneCycleLR(optimizer, max_lr=cfg.lr,
                       epochs=cfg.epochs, steps_per_epoch=len(train_loader))
# torch.device(...) accepts a torch.device or a plain 'cuda'/'cpu' string, so this
# works even if `device` was rebound to a string by another cell.
scaler = torch.cuda.amp.GradScaler(enabled=(torch.device(device).type == "cuda"))


In [None]:
# Per-epoch histories and best-checkpoint tracking.
train_losses = []
val_losses = []
lrs = []
best_val_loss = float("inf")
best_path = os.path.join(cfg.out_dir, "fusion_best.pth")

In [None]:
# ============================================================
# 7. Helpers: encoding and metrics
# ============================================================

def encode_batch(batch):
    """Map a batch of precomputed features to (query, track) fused embeddings.

    `batch` is a dict as produced by collate_fn_feats, with 'desc_audio_emb',
    'desc_text_emb', 'audio_emb' and 'lyrics_emb' tensors. The frozen CLAP and
    BiEncoder encoders already ran in the precompute cell, so only
    fusion_encoder is applied here. The inputs are fed exactly as in the
    training loop, so validation/test embeddings match what was trained.

    Returns:
        (query_fused, track_fused), one row per item of the batch.
    """
    query_fused = fusion_encoder(batch["desc_audio_emb"].to(device),
                                 batch["desc_text_emb"].to(device))
    track_fused = fusion_encoder(batch["audio_emb"].to(device),
                                 batch["lyrics_emb"].to(device))
    return query_fused, track_fused


@torch.no_grad()
def evaluate_val():
    """Mean per-batch CLIP loss over val_loader, appended to cfg.val_log_path.

    Each call appends {"epoch", "val_loss", "timestamp"} to the JSON list in
    cfg.val_log_path. "epoch" is the entry's index in that file, which keeps
    growing across notebook runs. A missing or unparsable log starts a fresh list.

    Returns:
        The average validation loss (float).

    Raises:
        ValueError: if val_loader yields no batches.
    """
    fusion_encoder.eval()
    if len(val_loader) == 0:
        raise ValueError("val_loader is empty: nothing to validate on")

    total_loss = 0.0
    for batch in tqdm(val_loader, desc="Validating"):
        q, t = encode_batch(batch)
        total_loss += clip_loss(q, t, cfg.temperature).item()
    avg = total_loss / len(val_loader)

    try:
        with open(cfg.val_log_path, "r") as f:
            logs = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        logs = []
    logs.append({"epoch": len(logs), "val_loss": avg, "timestamp": str(datetime.now())})
    with open(cfg.val_log_path, "w") as f:
        json.dump(logs, f, indent=2)
    return avg


@torch.no_grad()
def compute_embeddings(loader):
    """Encode every batch of `loader` in eval mode.

    Returns:
        (queries, tracks): the per-batch outputs of encode_batch concatenated
        along dim 0, in loader order.
    """
    fusion_encoder.eval()
    pairs = [encode_batch(batch) for batch in tqdm(loader, desc="Computing embeddings")]
    queries = torch.cat([q for q, _ in pairs], dim=0)
    tracks = torch.cat([t for _, t in pairs], dim=0)
    return queries, tracks


@torch.no_grad()
def compute_retrieval_metrics(q_emb, t_emb, k_values=(1, 5, 10, 20)):
    """Retrieval metrics where query i's only relevant target is row i of t_emb.

    Args:
        q_emb: query embeddings, one row per query.
        t_emb: target embeddings, row-aligned with q_emb.
        k_values: cut-offs for Recall@K and Precision@K.

    Returns:
        (mrr, recall_at_k, precision_at_k): MRR as a float plus two dicts
        {k: float}. With one relevant target per query, Precision@K equals
        Recall@K / K.

    Raises:
        ValueError: if there are no queries.
    """
    sim = q_emb @ t_emb.t()                     # (num_queries, num_targets)
    num_queries = sim.size(0)
    if num_queries == 0:
        raise ValueError("compute_retrieval_metrics needs at least one query")

    # Rank of the correct target = 1 + number of targets scored strictly higher.
    # One vectorised comparison replaces a full argsort per query in a Python
    # loop. On exact score ties the correct target is ranked first.
    positive = sim.diagonal().unsqueeze(1)      # (num_queries, 1)
    ranks = (sim > positive).sum(dim=1) + 1     # (num_queries,)

    mrr = (1.0 / ranks.double()).mean().item()
    recall_at_k, precision_at_k = {}, {}
    for k in k_values:
        recall = (ranks <= k).double().mean().item()
        recall_at_k[k] = recall
        precision_at_k[k] = recall / k
    return mrr, recall_at_k, precision_at_k

In [None]:
from IPython.display import clear_output

# ============================================================
# 8. Training with validation and a live loss plot
# ============================================================

PLOT_EVERY_STEPS = 50  # redraw the live train-loss curve every N optimizer steps


def plot_step_losses(losses, epoch, step):
    """Replace the cell output with an up-to-date per-step train-loss curve."""
    clear_output(wait=True)
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(losses, label="Train loss per step")
    ax.set_xlabel("Optimizer step")
    ax.set_ylabel("Loss")
    ax.set_title(f"Train loss (epoch {epoch}, step {step})")
    ax.grid(alpha=0.3)
    ax.legend()
    plt.show()
    plt.close(fig)  # don't accumulate one open Figure per redraw


print("\n" + "="*70)
print("START FUSION TRAINING")
print("="*70)

train_losses = []
val_losses = []
lrs = []

step_losses = []        # loss of every optimizer step
global_step = 0
# torch.device(...) accepts a torch.device or a 'cuda'/'cpu' string alike.
use_cuda = torch.device(device).type == "cuda"

for epoch in range(cfg.epochs):
    print(f"\nEpoch {epoch+1}/{cfg.epochs}")
    fusion_encoder.train()
    total_loss = 0.0

    for batch in tqdm(train_loader, desc="Training"):
        optimizer.zero_grad(set_to_none=True)

        with torch.cuda.amp.autocast(enabled=use_cuda, dtype=torch.bfloat16):
            q = fusion_encoder(batch['desc_audio_emb'].to(device),
                               batch['desc_text_emb'].to(device))
            t = fusion_encoder(batch['audio_emb'].to(device),
                               batch['lyrics_emb'].to(device))
            loss = clip_loss(q, t, cfg.temperature)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        loss_value = loss.item()  # one device->host sync per step instead of two
        total_loss += loss_value
        step_losses.append(loss_value)
        global_step += 1

        if global_step % PLOT_EVERY_STEPS == 0:
            plot_step_losses(step_losses, epoch + 1, global_step)

    train_loss = total_loss / len(train_loader)
    train_losses.append(train_loss)
    lrs.append(optimizer.param_groups[0]["lr"])  # LR after the epoch's last step
    print(f"Train loss: {train_loss:.4f}")

    # Validation
    val_loss = evaluate_val()
    val_losses.append(val_loss)
    print(f"Val loss: {val_loss:.4f}")

    # Keep the weights with the lowest validation loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(fusion_encoder.state_dict(), best_path)
        print("üèÜ New best fusion model saved")

# Per-epoch curves (also saved as PNG)
fig, (ax_loss, ax_lr) = plt.subplots(1, 2, figsize=(12, 5))
ax_loss.plot(train_losses, label="Train")
ax_loss.plot(val_losses, label="Val")
ax_loss.set_xlabel("Epoch"); ax_loss.set_ylabel("Loss")
ax_loss.set_title("Fusion Train/Val Loss")
ax_loss.grid(alpha=0.3); ax_loss.legend()

ax_lr.plot(lrs, label="LR")
ax_lr.set_yscale("log")
ax_lr.set_xlabel("Epoch"); ax_lr.set_ylabel("LR")
ax_lr.set_title("Learning rate")
ax_lr.grid(alpha=0.3); ax_lr.legend()

fig.tight_layout()
fig.savefig(os.path.join(cfg.out_dir, "fusion_losses.png"))
plt.show()
print("‚úì Saved loss curves to fusion_losses.png")

In [None]:
# ============================================================
# 9. Test-set evaluation: MRR, Recall@K, Precision@K
# ============================================================

print("\n" + "=" * 70, "EVALUATING ON TEST SET", "=" * 70, sep="\n")

# Restore the checkpoint with the best validation loss before scoring the test split.
fusion_encoder.load_state_dict(torch.load(best_path, map_location=device))
fusion_encoder.to(device).eval()
print("‚úì Loaded best fusion model")

q_emb, t_emb = compute_embeddings(test_loader)
print("Embeddings:", q_emb.shape, t_emb.shape)

mrr, recall_at_k, precision_at_k = compute_retrieval_metrics(q_emb, t_emb)

print("\n" + "=" * 70, "FUSION RETRIEVAL METRICS (Test)", "=" * 70, sep="\n")
print(f"MRR: {mrr:.4f}")
# The metric dicts are keyed by the default cut-offs (1, 5, 10, 20), in that order.
for k in recall_at_k:
    print(f"Recall@{k}:    {recall_at_k[k]:.4f}")
    print(f"Precision@{k}: {precision_at_k[k]:.4f}")

In [None]:
# Sanity check: list training tracks whose mp3 is missing or cannot be decoded.
# Reads the split JSON directly: FusionFeatDataset has no `.data` attribute, and
# this check is most useful *before* features are precomputed.
with open(cfg.train_json, 'r', encoding='utf-8') as f:
    train_items = json.load(f)

bad = []
for item in tqdm(train_items):
    tid = item['track_id']
    path = os.path.join(AUDIO_ROOT, f"{tid}.mp3")
    if not os.path.exists(path):
        bad.append((tid, "missing"))
        continue
    try:
        torchaudio.load(path)
    except Exception as e:  # collect every decode failure instead of stopping at the first
        bad.append((tid, str(e)))

print("Bad files:", len(bad))
bad[:10]
