In [1]:
# Import Libraries
import os, sys, findspark, numpy as np
from pyspark.sql import SparkSession
from pyspark.sql import functions as F, Window as W

In [2]:
# Spark needs a JDK (Java 17 here) and must launch its Python workers with this kernel's interpreter.
# Respect a JAVA_HOME that is already set in the environment; otherwise fall back to the
# author's Homebrew path.
# NOTE(review): the fallback is machine- and patch-version-specific (Homebrew openjdk@17 17.0.17);
# export JAVA_HOME before starting Jupyter on any other machine.
os.environ.setdefault(
    "JAVA_HOME",
    "/opt/homebrew/Cellar/openjdk@17/17.0.17/libexec/openjdk.jdk/Contents/Home",
)
os.environ["PYSPARK_PYTHON"] = sys.executable
os.environ["PYSPARK_DRIVER_PYTHON"] = sys.executable

# Build a local PySpark session: 4 cores, 10 GB driver memory, adaptive query execution on
findspark.init()
spark = (
    SparkSession.builder
    .appName("SpotifyRec")
    .master("local[4]")
    .config("spark.driver.memory", "10g")
    .config("spark.sql.adaptive.enabled", "true")
    .getOrCreate()
)
# Only show ERROR-level log messages (hides WARN/INFO noise for cleaner output)
spark.sparkContext.setLogLevel("ERROR")

Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
25/11/10 20:48:57 WARN Utils: Your hostname, Ethans-MacBook-Pro.local, resolves to a loopback address: 127.0.0.1; using 100.64.14.129 instead (on interface en0)
25/11/10 20:48:57 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/11/10 20:48:58 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
def _read_parquet_view(path, view_name=None):
    """Read a parquet directory into a Spark DataFrame, optionally registering it as a temp SQL view."""
    frame = spark.read.parquet(path)
    if view_name is not None:
        frame.createOrReplaceTempView(view_name)
    return frame

# Raw tables: interactions, users (playlists) and items (tracks), each also queryable via SQL
playlist_tracks = _read_parquet_view("parquet_data/playlist_tracks", "playlist_tracks")
playlists = _read_parquet_view("parquet_data/playlists", "playlists")
tracks = _read_parquet_view("parquet_data/tracks", "tracks")

# Model-ready features: item features, user features and item-user interaction edges
track_features = _read_parquet_view('parquet_data/track_features', "track_features")
playlist_features = _read_parquet_view('parquet_data/playlist_features', "playlist_features")
edges = _read_parquet_view('parquet_data/edges', "edges")

# Training / validation (playlist, positive track) pairs — not registered as SQL views
train_pairs = _read_parquet_view('parquet_data/train_pairs')
val_pairs = _read_parquet_view('parquet_data/val_pairs')

In [4]:
# DataFrame.show() prints the table itself and returns None, so wrapping it in print()
# only adds a stray "None" line after every table.
train_pairs.show(5)
val_pairs.show(5)
edges.show(5)
track_features.show(5)
playlist_features.show(5)

+---+--------------------+--------------------+-------+
|pid|              tokens|                mask|pos_tid|
+---+--------------------+--------------------+-------+
|  2|[2055362, 902534,...|[1, 1, 1, 1, 1, 1...| 285459|
|  8|[953644, 1226541,...|[1, 1, 1, 1, 1, 1...| 909553|
| 11|[1279865, 1780757...|[1, 1, 1, 1, 1, 1...| 721318|
| 35|[2102277, 1427277...|[1, 1, 1, 1, 1, 1...| 372978|
| 49|[73755, 1607709, ...|[1, 1, 1, 1, 1, 1...| 152751|
+---+--------------------+--------------------+-------+
only showing top 5 rows
None
+---+--------------------+--------------------+-------+
|pid|              tokens|                mask|pos_tid|
+---+--------------------+--------------------+-------+
|192|[2137547, 747563,...|[1, 1, 1, 1, 1, 1...| 146944|
|249|[1012256, 201022,...|[1, 1, 1, 1, 1, 1...|1715753|
|337|[1385456, 338913,...|[1, 1, 1, 1, 1, 1...| 712726|
|352|[598275, 1742374,...|[1, 1, 1, 1, 1, 1...|1311226|
|381|[2126317, 877997,...|[1, 1, 1, 1, 1, 1...|1856934|
+---+--------------

In [5]:
import pyarrow.dataset as ds
import torch

In [6]:
# Item tower inputs: read the track feature columns from parquet with PyArrow and
# convert each one to a NumPy array (one entry per row of track_features).
PARQUET_T_FEATURES = 'parquet_data/track_features'

ttab = (
    ds.dataset(PARQUET_T_FEATURES, format='parquet')
      .to_table(columns=['tid', 'aid', 'alid', 'z_log_track_cnt', 'z_log_artist_cnt'], use_threads=True)
)

# Integer IDs: track, artist, album — each of shape (N_rows,)
tid, aid, alid = (np.array(ttab[name], dtype=np.int64) for name in ('tid', 'aid', 'alid'))

# Numeric side features as an (N_rows, 2) float32 matrix:
# column 0 = z_log_track_cnt, column 1 = z_log_artist_cnt
tnum = np.column_stack(
    [np.array(ttab[name], dtype=np.float32) for name in ('z_log_track_cnt', 'z_log_artist_cnt')]
)

In [7]:
# Embedding vocabulary sizes: largest ID seen plus one, so every ID — including the
# reserved padding ID 0 (padding_idx=0 in the model) — is a valid row index.
n_tracks, n_artists, n_albums = (int(ids.max()) + 1 for ids in (tid, aid, alid))

In [8]:
# Scatter the row-aligned track features into dense tables indexed directly by tid,
# each of length n_tracks. Rows for tids that never appear in track_features
# (e.g. the padding ID 0, unless it has a row) stay all-zero.
# Guards: under NumPy fancy indexing a negative tid silently wraps to the end of the
# table, and a repeated tid silently keeps only its last row — fail loudly instead.
assert tid.min() >= 0, "track_features contains negative tid values"
assert np.unique(tid).size == tid.size, "track_features contains duplicate tid values"

aid_by_tid = np.zeros((n_tracks,), dtype=np.int64)
alid_by_tid = np.zeros((n_tracks,), dtype=np.int64)
tnum_by_tid = np.zeros((n_tracks, 2), dtype=np.float32)

aid_by_tid[tid] = aid
alid_by_tid[tid] = alid
tnum_by_tid[tid] = tnum
# Index is tid; values are the artist ID, album ID, or the 2 numeric features

In [9]:
# Expose the lookup tables as torch tensors so the model can gather rows with tensor
# indices, e.g. self.aid_by_tid[pos_ids].
# torch.as_tensor wraps a NumPy array without copying (like torch.from_numpy) but returns
# an existing tensor unchanged, so re-running this cell no longer raises TypeError.
aid_by_tid = torch.as_tensor(aid_by_tid)
alid_by_tid = torch.as_tensor(alid_by_tid)
tnum_by_tid = torch.as_tensor(tnum_by_tid)

In [10]:
# User tower inputs: a dense per-pid matrix of playlist metadata features.
#   - 4 log-scaled statistics: n_tracks, n_artists, n_albums, pl_duration
#   - 1 log-scaled recency feature: days_mod
#   - the collaborative flag, cast to float
PARQUET_P_FEATURES = 'parquet_data/playlist_features'
ptab = ds.dataset(PARQUET_P_FEATURES, format='parquet').to_table()
pid = np.array(ptab['pid'])

pf = np.column_stack([
    np.array(ptab[name], dtype=np.float32)
    for name in ('log_n_tracks', 'log_n_artists', 'log_n_albums',
                 'log_pl_duration', 'log_days_mod', 'collaborative')
])

# Scatter rows into a table indexed by pid; pids absent from the parquet keep zero features.
n_pids = int(pid.max()) + 1
pl_feat_by_pid = np.zeros((n_pids, pf.shape[1]), dtype=np.float32)
pl_feat_by_pid[pid] = pf
# Shape (n_pids, 6). The two-tower user tower fuses this metadata with the pooled token embedding.
pl_feat_by_pid = torch.from_numpy(pl_feat_by_pid)

In [11]:
# Locations of the (pid, tokens, mask, pos_tid) training / validation pairs
PARQUET_TRAIN = 'parquet_data/train_pairs'
PARQUET_VAL = 'parquet_data/val_pairs'

# Training constants
max_length = 20
embed_dim = 64
batch_size = 512  # increase if the accelerator has memory to spare (512 default)

# Pick the best available backend: CUDA first, then Apple MPS, else CPU
if torch.cuda.is_available():
    DEVICE = 'cuda'
elif torch.backends.mps.is_available():
    DEVICE = 'mps'
else:
    DEVICE = 'cpu'

In [12]:
class PlaylistPairDataset(torch.utils.data.Dataset):
    """Map-style dataset of (playlist, held-out positive track) pairs read from parquet.

    The whole table is materialised in memory as NumPy arrays at construction time.
    Indexing returns one playlist as a dict of tensors with keys
    'pid', 'tokens', 'mask' and 'pos'.
    """

    def __init__(self, parquet_path):
        wanted = ['pid', 'tokens', 'mask', 'pos_tid']
        table = ds.dataset(parquet_path, format='parquet').to_table(columns=wanted)
        # Scalar columns -> 1-D arrays; list columns -> 2-D arrays (np.stack needs
        # every list in a column to have the same length).
        self.pid = np.asarray(table['pid']).astype(np.int64)
        self.tokens = np.stack(table['tokens'].to_pylist()).astype(np.int64)
        self.mask = np.stack(table['mask'].to_pylist()).astype(np.float32)
        self.pos = np.asarray(table['pos_tid']).astype(np.int64)

    def __len__(self):
        # One sample per (playlist, positive) row
        return len(self.pos)

    def __getitem__(self, i):
        # Scalars become 0-d int64 tensors; row slices are wrapped without copying
        return dict(
            pid=torch.tensor(self.pid[i], dtype=torch.long),
            tokens=torch.from_numpy(self.tokens[i]),
            mask=torch.from_numpy(self.mask[i]),
            pos=torch.tensor(self.pos[i], dtype=torch.long),
        )

In [13]:
def make_loader(path, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True):
    """Build a PlaylistPairDataset from ``path`` and wrap it in a DataLoader.

    Args:
        path: parquet directory of (pid, tokens, mask, pos_tid) pairs.
        batch_size: rows per batch (defaults to the notebook-level ``batch_size``).
        shuffle: reshuffle rows every epoch (use False for evaluation).
        num_workers: DataLoader worker processes; 0 loads in the main process.
        drop_last: drop the final short batch. Defaults to True so every in-batch
            contrastive step sees exactly ``batch_size`` candidates. Pass False for
            evaluation passes that must cover every row.

    Returns:
        (dataset, loader) tuple.
    """
    dataset = PlaylistPairDataset(path)

    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=(DEVICE == 'cuda'),  # pinned host memory only helps host->CUDA copies
        drop_last=drop_last,
    )

    # Multiprocessing-specific knobs are only valid when worker processes exist
    if num_workers > 0:
        loader_kwargs.update(persistent_workers=True, prefetch_factor=2)

    return dataset, torch.utils.data.DataLoader(dataset, **loader_kwargs)


In [14]:
# make_loader returns (dataset, loader): each *_ds holds every row in memory as NumPy
# arrays, each *_loader yields batches of batch_size rows (shuffled for train, ordered for val).
train_ds, train_loader = make_loader(path=PARQUET_TRAIN, shuffle=True, num_workers=0)
val_ds, val_loader = make_loader(path=PARQUET_VAL, shuffle=False, num_workers=0)

In [15]:
import torch.nn as nn
import torch.nn.functional as F

# Dual-encoder ("two-tower") retrieval model.
# Both towers emit L2-normalised vectors in one shared space, so the dot product of a
# playlist vector and a track vector is their cosine similarity (the relevance score).
# n_* are embedding vocabulary sizes; *_by_tid and pl_feat_by_pid are the dense lookup
# tensors built earlier in this notebook.
class TwoTower(nn.Module):
    def __init__(self,
                 n_tracks,
                 n_artists,
                 n_albums,
                 aid_by_tid,
                 alid_by_tid,
                 tnum_by_tid,
                 pl_feat_by_pid,
                 embed_dim=64,
                 mlp_hidden=128,
                 share_embed=True,
                 cat_dim=32):
        """
        Args:
            n_tracks, n_artists, n_albums: embedding vocab sizes; row 0 of every ID table
                is the padding row (padding_idx=0).
            aid_by_tid, alid_by_tid: tensors mapping tid -> artist ID / album ID.
            tnum_by_tid: numeric track features indexed by tid; the last dim must be 2
                (the input width of item_num_mlp).
            pl_feat_by_pid: playlist metadata indexed by pid, shape (n_pids, F).
            embed_dim: output width of both towers.
            mlp_hidden: hidden width of the numeric-feature MLPs (the item MLP uses half).
            share_embed: if True the playlist tower reuses the track ID embedding table.
            cat_dim: width of the artist / album embeddings.
        """
        super().__init__()

        # Lookup buffers: registered (not trained) so they follow .to(device) and are
        # saved in state_dict checkpoints.
        # Memory: they live on the model's device; if that is tight, keep them on CPU and
        # move only the per-batch slices to the device in forward.
        self.register_buffer('aid_by_tid', aid_by_tid.long())
        self.register_buffer('alid_by_tid', alid_by_tid.long())
        self.register_buffer('tnum_by_tid', tnum_by_tid.float())
        self.register_buffer('pl_feat_by_pid', pl_feat_by_pid.float())

        # Track-ID embeddings. Each table costs vocab_size * embed_dim * 4 bytes (float32);
        # share_embed=True reuses a single table for both towers, halving that cost.
        self.item_emb = nn.Embedding(n_tracks, embed_dim, padding_idx=0)
        self.user_emb = self.item_emb if share_embed else nn.Embedding(n_tracks, embed_dim, padding_idx=0)

        # Side-ID embeddings for the item tower (album could be dropped to save memory)
        self.artist_emb = nn.Embedding(n_artists, cat_dim, padding_idx=0)
        self.album_emb = nn.Embedding(n_albums, cat_dim, padding_idx=0)

        # Numeric-feature MLPs: project side features to embed_dim so they can be fused
        self.item_num_mlp = nn.Sequential(
            nn.Linear(2, mlp_hidden // 2), nn.ReLU(), nn.Linear(mlp_hidden // 2, embed_dim)
        )
        self.user_num_mlp = nn.Sequential(
            nn.Linear(pl_feat_by_pid.size(1), mlp_hidden), nn.ReLU(), nn.Linear(mlp_hidden, embed_dim)
        )

        # Fusion heads: concatenate a tower's parts, project back to embed_dim, LayerNorm.
        # Item: track ID emb + artist emb + album emb + numeric projection
        self.item_fuse = nn.Sequential(
            nn.Linear(embed_dim + cat_dim + cat_dim + embed_dim, embed_dim),
            nn.LayerNorm(embed_dim)
        )
        # User: pooled sequence vector + playlist-metadata projection
        self.user_fuse = nn.Sequential(
            nn.Linear(embed_dim + embed_dim, embed_dim),
            nn.LayerNorm(embed_dim)
        )

        # Learnable InfoNCE softmax temperature (lower = sharper distribution).
        # forward() clamps it to [1e-3, 1.0]; outside that range the clamp passes no
        # gradient back to tau, so it cannot recover on its own.
        self.tau = nn.Parameter(torch.tensor(0.07))

    # USER TOWER
    def playlist_forward(self, pids, tokens, mask):
        """Encode playlists from their seed tracks plus playlist metadata.

        Args:
            pids: (B,) playlist IDs, used to gather rows of pl_feat_by_pid.
            tokens: (B, L) track IDs of the playlist's tracks (0 = padding).
            mask: (B, L) 1 for real tokens and 0 for padding; any numeric or bool dtype.

        Returns:
            (B, embed_dim) unit-norm playlist vectors.
        """
        # Masked mean pool of the token embeddings. One float copy of the mask is used for
        # both the weighting and the count, so int/bool masks pool identically to float ones.
        m = mask.float()
        e = self.user_emb(tokens) * m.unsqueeze(-1)
        denom = m.sum(dim=1, keepdim=True).clamp_min(1.0)  # all-padding rows pool to zeros, not NaN
        seq_vec = F.normalize(e.sum(dim=1) / denom, dim=-1)

        # Playlist metadata through its MLP, fused with the sequence vector
        pf_vec = F.normalize(self.user_num_mlp(self.pl_feat_by_pid[pids]), dim=-1)
        fused = torch.cat([seq_vec, pf_vec], dim=-1)
        return F.normalize(self.user_fuse(fused), dim=-1)

    # ITEM TOWER
    def track_forward(self, pos_ids):
        """Encode tracks from their ID, artist, album and numeric features.

        Args:
            pos_ids: integer tensor of track IDs (any shape).

        Returns:
            Unit-norm track vectors with a trailing embed_dim axis.
        """
        # Gather per-item features through the tid-indexed lookup buffers
        base = self.item_emb(pos_ids)
        a = self.artist_emb(self.aid_by_tid[pos_ids])
        al = self.album_emb(self.alid_by_tid[pos_ids])
        xn = self.item_num_mlp(self.tnum_by_tid[pos_ids])
        fused = torch.cat([base, a, al, xn], dim=-1)
        return F.normalize(self.item_fuse(fused), dim=-1)

    def forward(self, pids, tokens, mask, pos):
        """Score every playlist in the batch against every positive track in the batch.

        Returns (B, B) logits: cosine similarities divided by the clamped temperature.
        Row i's positive is column i; the other columns act as in-batch negatives.
        Cost is O(B^2 * D).
        """
        p = self.playlist_forward(pids, tokens, mask)
        t = self.track_forward(pos)
        logits = (p @ t.T) / torch.clamp(self.tau, min=1e-3, max=1.0)
        return logits


# Memory notes:
# - item_emb dominates memory; shrinking embed_dim reduces it linearly.
# - The B x B logit matrix grows quadratically with batch size; for very large batches
#   prefer sampled negatives over full in-batch scoring.
# - nn.Embedding(sparse=True) with SparseAdam would cut optimizer-state memory and speed up
#   training on large vocabularies.

In [16]:
def cross_entropy_multipos(logits: torch.Tensor, pos_ids: torch.Tensor) -> torch.Tensor:
    """InfoNCE-style cross-entropy that accepts several positives per row.

    Args:
        logits: (B, B) scores; row i is playlist i and column j is the positive track of row j.
        pos_ids: (B,) track ID of each row's positive.

    Every column whose track ID equals row i's positive counts as a positive for row i
    (the diagonal always does), so duplicate tracks in a batch are not treated as
    negatives. Per-row loss is -log(total softmax mass on the positives), averaged over rows.
    """
    same_track = pos_ids.unsqueeze(1) == pos_ids.unsqueeze(0)            # (B, B)
    all_mass = torch.logsumexp(logits, dim=1)                            # (B,)
    positive_only = logits.masked_fill(same_track.logical_not(), float('-inf'))
    positive_mass = torch.logsumexp(positive_only, dim=1)                # (B,)
    return torch.mean(all_mass - positive_mass)

In [17]:
# Standalone two-tower instance (tied embeddings via the default share_embed=True) on DEVICE.
# NOTE(review): run_experiment() builds its own fresh model; this one uses embed_dim=128 /
# cat_dim=64 rather than the embed_dim constant — confirm that is intended.
model = TwoTower(
    n_tracks, n_artists, n_albums,
    aid_by_tid, alid_by_tid, tnum_by_tid,
    pl_feat_by_pid,
    embed_dim=128, mlp_hidden=128, cat_dim=64,
)
model = model.to(DEVICE)

In [18]:
import numpy as np, pyarrow.dataset as ds
# Sanity check on the training pairs: 0 is the padding ID, so it may appear inside
# tokens but should never be a positive target.
chk = (
    ds.dataset('parquet_data/train_pairs', format='parquet')
      .to_table(columns=['tokens', 'pos_tid'])
)
tokens0 = any(0 in token_row for token_row in chk['tokens'].to_pylist())
pos0    = (np.array(chk['pos_tid']) == 0).any()
print("tokens contain 0:", tokens0, " | pos contains 0:", pos0)

tokens contain 0: True  | pos contains 0: False


In [19]:
import time

# --- Utilities ---

# Computes Recall@K using in-batch negatives
def inbatch_recall_at_k(logits: torch.Tensor, k_list=(10, 50)):
    """In-batch Recall@K for a (B, B) score matrix whose diagonal holds the positives.

    Args:
        logits: (B, B) scores; row i's positive is column i.
        k_list: cut-offs to report.

    Returns:
        dict mapping "Recall@{k}" to the fraction of rows whose positive ranks in the top
        min(k, B) columns. Keys are labelled with the *requested* k, so 'Recall@10' and
        'Recall@50' always exist even when a batch has fewer than k rows. The training
        and evaluation loops index those keys directly.
    """
    ks = list(k_list)
    if not ks:
        return {}
    B = logits.size(0)
    # Column indices of the top predictions per row, best first: (B, min(max_k, B))
    _, topk = logits.topk(k=min(max(ks), B), dim=1)
    # Row i's correct column is i (the diagonal)
    targets = torch.arange(B, device=logits.device).view(-1, 1)
    hits = topk == targets
    return {f"Recall@{k}": hits[:, :min(k, B)].any(dim=1).float().mean().item() for k in ks}

# Evaluation pass on DataLoader using in-batch negatives, no gradients
@torch.no_grad()
def eval_inbatch(model, loader, device=DEVICE, ks=(10, 50)):
    """Validation InfoNCE loss and in-batch Recall@K, without changing any weights.

    Returns:
        dict with one "Recall@{k}" entry per requested k plus 'loss', each averaged over
        samples (batches weighted by their size). If the loader yields no batches (e.g. a
        validation set smaller than batch_size with drop_last=True), every value is NaN
        instead of raising ZeroDivisionError.
    """
    model.eval()  # switch off train-only behaviour (e.g. dropout) for a deterministic pass
    tot = 0
    sum_r = {f"Recall@{k}": 0.0 for k in ks}
    sum_loss = 0.0

    for batch in loader:
        pids   = batch['pid'].to(device, non_blocking=True)
        tokens = batch['tokens'].to(device, non_blocking=True)
        mask   = batch['mask'].to(device, non_blocking=True)
        pos    = batch['pos'].to(device, non_blocking=True)

        logits = model(pids, tokens, mask, pos)   # (B, B); the model already divides by tau
        B = logits.size(0)
        y = torch.arange(B, device=device)
        # InfoNCE = cross-entropy over each row's softmax with the target on the diagonal
        loss = F.cross_entropy(logits, y)

        tot += B
        sum_loss += loss.item() * B
        for key, value in inbatch_recall_at_k(logits, ks).items():
            sum_r[key] += value * B

    if tot == 0:
        out = {k: float('nan') for k in sum_r}
        out['loss'] = float('nan')
        return out

    out = {k: (v / tot) for k, v in sum_r.items()}
    out['loss'] = sum_loss / tot
    return out

# --- One epoch of training with in-batch InfoNCE ---

def train_one_epoch(model, loader, optimizer, device=DEVICE, log_every=100):
    """Train for one pass over ``loader`` with symmetric in-batch InfoNCE.

    Each step averages the playlist->track and track->playlist cross-entropies over the
    (B, B) logit matrix, clips the global grad norm to 5.0 and takes one optimizer step.

    Returns:
        dict with sample-weighted 'train_loss', 'train_R@10', 'train_R@50' and
        'throughput_sps' (samples per second over the whole epoch).
    """
    model.train()
    start = time.time()
    seen = 0
    totals = {'loss': 0.0, 'r10': 0.0, 'r50': 0.0}

    for step, batch in enumerate(loader, start=1):
        inputs = [batch[key].to(device, non_blocking=True) for key in ('pid', 'tokens', 'mask', 'pos')]
        logits = model(*inputs)                   # (B, B), positives on the diagonal
        n = logits.size(0)
        labels = torch.arange(n, device=device)

        # Symmetric InfoNCE: rows score playlist->track, columns score track->playlist
        loss = 0.5 * (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels))

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()

        with torch.no_grad():
            recalls = inbatch_recall_at_k(logits, (10, 50))
        seen += n
        totals['loss'] += loss.item() * n
        totals['r10'] += recalls['Recall@10'] * n
        totals['r50'] += recalls['Recall@50'] * n

        if step % log_every == 0:
            rate = seen / max(time.time() - start, 1e-6)
            print(f"[step {step}] "
                  f"loss={totals['loss']/seen:.4f} "
                  f"R@10={totals['r10']/seen:.4f} "
                  f"R@50={totals['r50']/seen:.4f} "
                  f"ips={rate:.1f} samples/s")

    rate = seen / max(time.time() - start, 1e-6)
    denom = max(seen, 1)
    return {
        "train_loss": totals['loss'] / denom,
        "train_R@10": totals['r10'] / denom,
        "train_R@50": totals['r50'] / denom,
        "throughput_sps": rate
    }


In [20]:
from torch.optim.lr_scheduler import SequentialLR, CosineAnnealingLR, LinearLR

def build_param_groups(model, weight_decay=1e-4):
    """Split a model's parameters into AdamW groups with and without weight decay.

    Excluded from decay (weight_decay=0.0):
      * embedding tables whose parameter name contains item_emb / user_emb / artist_emb /
        album_emb;
      * every 0-D or 1-D parameter — biases, LayerNorm weights/biases and the learned
        temperature ``tau``. Decaying these rescales the model rather than regularising it.
    All remaining (matrix) weights get ``weight_decay``.

    Returns:
        [{'params': decay, 'weight_decay': weight_decay},
         {'params': no_decay, 'weight_decay': 0.0}]
    """
    embed_keys = ('item_emb', 'user_emb', 'artist_emb', 'album_emb')
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if param.ndim <= 1 or any(key in name for key in embed_keys):
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {'params': decay, 'weight_decay': weight_decay},
        {'params': no_decay, 'weight_decay': 0.0},
    ]

def build_scheduler(optimizer, total_steps, warmup_steps=1000, min_lr=1e-5):
    """Linear warmup then cosine decay to ``min_lr``; call ``.step()`` once per optimizer step.

    Warmup ramps the LR from 0.1x to 1x of the optimizer's base LR over ``warmup_steps``
    steps. Cosine annealing then runs over the remaining ``total_steps - warmup_steps``.

    With ``warmup_steps <= 0`` plain cosine annealing over ``total_steps`` is returned.
    Chaining a zero-length LinearLR would leave the optimizer at 0.1x LR when the cosine
    phase starts, so the schedule would never reach the base LR.

    NOTE: if warmup_steps >= total_steps the cosine phase is never reached.
    """
    if warmup_steps <= 0:
        return CosineAnnealingLR(optimizer, T_max=max(1, total_steps), eta_min=min_lr)
    sched_warm = LinearLR(optimizer, start_factor=0.1, total_iters=warmup_steps)
    sched_cos  = CosineAnnealingLR(optimizer, T_max=max(1, total_steps - warmup_steps), eta_min=min_lr)
    return SequentialLR(optimizer, [sched_warm, sched_cos], milestones=[warmup_steps])

In [21]:
import os, time, math
import torch
import torch.nn.functional as F

def train_epoch_with_accum(model, train_loader, optimizer, scheduler, device, log_every, accum_steps, use_multipos):
    """One training epoch with gradient accumulation over ``accum_steps`` mini-batches.

    Loss is the average of the playlist->track and track->playlist directions.
    With ``use_multipos`` it uses cross_entropy_multipos, which treats repeated tracks in
    the batch as extra positives; otherwise it is vanilla InfoNCE with diagonal targets.

    Each mini-batch loss is divided by ``accum_steps`` before backward, so the accumulated
    gradient is an average. After every ``accum_steps`` batches, gradients are clipped to
    norm 5.0, the optimizer steps, grads are cleared and the scheduler (if any) advances once.

    NOTE: if len(train_loader) is not a multiple of accum_steps, the trailing partial
    group's gradients are left on the parameters without an optimizer step. The next
    call's initial zero_grad discards them.

    Returns:
        dict with sample-weighted 'train_loss' (un-scaled by accum_steps), 'train_R@10',
        'train_R@50' and 'throughput_sps'.
    """
    model.train()
    started = time.time()
    seen = 0
    loss_sum = r10_sum = r50_sum = 0.0

    optimizer.zero_grad(set_to_none=True)

    for i, batch in enumerate(train_loader):
        pids, tokens, mask, pos = (batch[key].to(device, non_blocking=True) for key in ('pid', 'tokens', 'mask', 'pos'))

        logits = model(pids, tokens, mask, pos)  # (B, B)
        B = logits.size(0)
        y = torch.arange(B, device=device)

        if use_multipos:
            directional = (cross_entropy_multipos(logits, pos), cross_entropy_multipos(logits.t(), pos))
        else:
            directional = (F.cross_entropy(logits, y), F.cross_entropy(logits.t(), y))

        loss = 0.5*(directional[0] + directional[1]) / accum_steps
        loss.backward()

        is_update_step = ((i + 1) % accum_steps) == 0
        if is_update_step:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            if scheduler is not None:
                scheduler.step()

        with torch.no_grad():
            recalls = inbatch_recall_at_k(logits, (10, 50))
        seen += B
        loss_sum += float(loss.item()) * B * accum_steps   # undo the /accum_steps scaling for logging
        r10_sum += recalls['Recall@10'] * B
        r50_sum += recalls['Recall@50'] * B

        if (i + 1) % log_every == 0:
            ips = seen / max(time.time() - started, 1e-6)
            print(f"[step {i+1}] loss={loss_sum/seen:.4f} "
                  f"R@10={r10_sum/seen:.4f} R@50={r50_sum/seen:.4f} "
                  f"ips={ips:.1f}/s")

    ips = seen / max(time.time() - started, 1e-6)
    denom = max(seen, 1)
    return {
        "train_loss": loss_sum / denom,
        "train_R@10": r10_sum / denom,
        "train_R@50": r50_sum / denom,
        "throughput_sps": ips
    }

def run_experiment(cfg):
    """Train a fresh TwoTower with early stopping on validation in-batch Recall@50.

    Required cfg keys:
      epochs, batch_size, accum_steps, share_embed, use_multipos, base_lr,
      use_cosine, warmup_steps, min_lr, patience, log_every
    Optional cfg keys [default]:
      num_workers [0], adamw_split_decay [True],
      checkpoint_path ['checkpoints/twotower_best.pt'],
      embed_dim [notebook-level embed_dim], mlp_hidden [128], cat_dim [32],
      seed [None = do not reseed]

    Returns:
        (best validation Recall@50, path of the best checkpoint).
    """
    print("Config:", cfg)

    seed = cfg.get('seed')
    if seed is not None:
        # Seeds torch's global RNG, which drives weight init and DataLoader shuffling
        torch.manual_seed(seed)

    # accum_steps <= 0 would make the `% accum_steps` update check divide by zero
    accum_steps = max(1, int(cfg['accum_steps']))
    num_workers = cfg.get('num_workers', 0)

    # (Re)build loaders so batch_size / num_workers follow this config
    _train_ds, train_loader = make_loader(PARQUET_TRAIN, batch_size=cfg['batch_size'], shuffle=True,  num_workers=num_workers)
    _val_ds,   val_loader   = make_loader(PARQUET_VAL,   batch_size=cfg['batch_size'], shuffle=False, num_workers=num_workers)

    # Fresh model per run
    model = TwoTower(
        n_tracks=n_tracks,
        n_artists=n_artists,
        n_albums=n_albums,
        aid_by_tid=aid_by_tid,
        alid_by_tid=alid_by_tid,
        tnum_by_tid=tnum_by_tid,
        pl_feat_by_pid=pl_feat_by_pid,
        embed_dim=cfg.get('embed_dim', embed_dim),
        mlp_hidden=cfg.get('mlp_hidden', 128),
        share_embed=cfg['share_embed'],
        cat_dim=cfg.get('cat_dim', 32),
    ).to(DEVICE)

    # AdamW, optionally with separate decay / no-decay parameter groups
    if cfg.get('adamw_split_decay', True):
        opt_params = build_param_groups(model)
    else:
        opt_params = model.parameters()

    optimizer = torch.optim.AdamW(opt_params, lr=cfg['base_lr'], betas=(0.9, 0.98))

    if cfg['use_cosine']:
        # The scheduler advances once per optimizer step. train_epoch_with_accum steps only
        # after each complete group of accum_steps batches: len // accum_steps per epoch.
        total_steps = (len(train_loader) // accum_steps) * cfg['epochs']
        scheduler = build_scheduler(optimizer, total_steps, warmup_steps=cfg['warmup_steps'], min_lr=cfg['min_lr'])
    else:
        scheduler = None

    best_val = -1.0
    no_improve = 0
    BEST_PATH = cfg.get('checkpoint_path', 'checkpoints/twotower_best.pt')
    ckpt_dir = os.path.dirname(BEST_PATH)
    if ckpt_dir:  # a bare filename has no directory part, and os.makedirs('') raises
        os.makedirs(ckpt_dir, exist_ok=True)

    for epoch in range(1, cfg['epochs']+1):
        print(f"\n=== Epoch {epoch}/{cfg['epochs']} (steps/epoch ≈ {len(train_loader)}) ===")
        tr = train_epoch_with_accum(
            model, train_loader, optimizer, scheduler,
            device=DEVICE, log_every=cfg['log_every'],
            accum_steps=accum_steps, use_multipos=cfg['use_multipos']
        )
        va = eval_inbatch(model, val_loader, device=DEVICE, ks=(10, 50))

        print(f"Train: loss={tr['train_loss']:.4f} R@10={tr['train_R@10']:.4f} R@50={tr['train_R@50']:.4f} "
              f"throughput={tr['throughput_sps']:.1f}/s  lr={optimizer.param_groups[0]['lr']:.2e}")
        print(f"Valid: loss={va['loss']:.4f}  R@10={va['Recall@10']:.4f}  R@50={va['Recall@50']:.4f}")

        # Early stopping on validation Recall@50; checkpoint only on a real improvement
        if va['Recall@50'] > best_val + 1e-4:
            best_val = va['Recall@50']
            no_improve = 0
            torch.save({"model": model.state_dict(),
                        "val_R@50": best_val,
                        "epoch": epoch}, BEST_PATH)
            print(f"✓ New best R@50={best_val:.4f} — checkpointed to {BEST_PATH}")
        else:
            no_improve += 1
            print(f"(no improvement for {no_improve} epoch(s))")
            if no_improve >= cfg['patience']:
                print(f"Early stopping at epoch {epoch}.")
                break

    return best_val, BEST_PATH


In [22]:
# Baseline: roughly the earlier setup — tied embeddings, vanilla InfoNCE, constant LR.
cfg_baseline = dict(
    epochs=3,
    batch_size=512,
    accum_steps=1,
    share_embed=True,       # tied embeddings
    use_multipos=False,     # vanilla InfoNCE
    base_lr=3e-4,
    use_cosine=False,
    warmup_steps=1000,
    min_lr=1e-5,
    patience=3,
    log_every=100,
    num_workers=0,
    adamw_split_decay=True,
    checkpoint_path='checkpoints/exp_baseline.pt',
)

# Improved (recommended): the baseline with the overrides below. Key order is unchanged
# because every overridden key already exists in cfg_baseline.
cfg_improved = {
    **cfg_baseline,
    'epochs': 8,
    'accum_steps': 4,            # effective batch 512 * 4 = 2048
    'share_embed': False,        # separate (untied) embedding tables per tower
    'use_multipos': True,        # multi-positive InfoNCE
    'base_lr': 1e-3,             # higher peak LR, paired with warmup + cosine decay
    'use_cosine': True,
    'checkpoint_path': 'checkpoints/exp_improved.pt',
}


In [None]:
def _run_and_report(label, cfg):
    """Run one experiment, then print its best validation Recall@50 and checkpoint path."""
    best, path = run_experiment(cfg)
    print(f"{label} best R@50:", best, "| ckpt:", path)
    return best, path

best_baseline, path_baseline = _run_and_report("Baseline", cfg_baseline)
best_improved, path_improved = _run_and_report("Improved", cfg_improved)


In [25]:
import torch
import numpy as np
from tqdm import tqdm

@torch.no_grad()
def build_item_index(model, n_tracks, device='cpu', batch_size=4096):
    """Embed every non-padding track ID 1..n_tracks-1 with the model's item tower.

    Args:
        model: a TwoTower (anything exposing ``eval()`` and ``track_forward(ids)``).
        n_tracks: vocabulary size; ID 0 (padding) is skipped.
        device: device the ID batches are created on; it must match the model's device.
            Pass None to use the device of the model's first parameter.
        batch_size: number of IDs encoded per forward pass.

    Returns:
        (index_embeds, tids): L2-normalised embeddings on CPU, one row per ID, and the
        row-aligned int64 IDs (n_tracks - 1 of each).

    Raises:
        ValueError: if n_tracks < 2 (there is no non-padding ID to index) or batch_size < 1.
    """
    if n_tracks < 2:
        raise ValueError(f"n_tracks must be >= 2 to index any non-padding track, got {n_tracks}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if device is None:
        device = next(model.parameters()).device

    model.eval()
    all_embeds, all_ids = [], []

    for start in tqdm(range(1, n_tracks, batch_size), desc="Building item index"):
        end = min(start + batch_size, n_tracks)
        tids = torch.arange(start, end, device=device)
        all_embeds.append(model.track_forward(tids).cpu())
        all_ids.append(tids.cpu())

    index_embeds = torch.cat(all_embeds, dim=0)
    tids = torch.cat(all_ids, dim=0)
    # TwoTower.track_forward already returns unit vectors; re-normalising keeps the index
    # safe for cosine search with any other item encoder.
    index_embeds = torch.nn.functional.normalize(index_embeds, dim=-1)
    return index_embeds, tids


@torch.no_grad()
def evaluate_full_corpus(model, val_loader, item_index, ks=(10, 50, 100), device='cpu', score_chunk=4096):
    """
    Full-corpus Recall@K: rank every indexed item for each validation query.

    Memory-safe streaming retrieval:
      • The item index stays on CPU; only ``score_chunk`` rows at a time are
        copied to ``device`` and scored against the batch's query vectors.
      • A running top-``max(ks)`` (scores plus row positions into the index) is
        kept on ``device`` and merged with each chunk's top-k, so at most a
        (B, score_chunk) score matrix is materialized at once.
      • Row positions are mapped to track IDs once per batch, instead of a
        device->CPU copy and CPU gather for every chunk.

    Args:
        model: exposes ``eval()`` and ``playlist_forward(pids, tokens, mask)``,
            which returns one query vector per row (B, D).
        val_loader: iterable of dicts with 'pid', 'tokens', 'mask' and 'pos'.
            'pos' is compared against retrieved track IDs, one target per row
            (presumably shape (B,) — TODO confirm against the dataset).
        item_index: ``(item_embs, tids)`` as returned by ``build_item_index``;
            row i of ``item_embs`` is the embedding of track ``tids[i]``.
        ks: cut-offs to report; duplicates are ignored.
        device: where queries are encoded and scored (str or torch.device).
        score_chunk: number of item rows scored per step.

    Returns:
        ``{"Recall@k": fraction of queries whose target is in the top k}`` for
        each k in ascending order. All values are 0.0 when ``val_loader`` yields
        no batches. If the index holds fewer than k items, all of them count as
        retrieved for that k.

    Raises:
        ValueError: if the item index is empty.
    """
    model.eval()
    item_embs_cpu, tids_cpu = item_index   # both stay on CPU
    n_items = item_embs_cpu.size(0)
    if n_items == 0:
        raise ValueError("item_index is empty; nothing to retrieve")
    ks = sorted(set(ks))
    max_k = ks[-1]
    # torch.device(...) accepts both strings and device objects; a plain
    # `device == 'mps'` is False when a torch.device is passed in.
    on_mps = torch.device(device).type == 'mps'

    total, hits_at_k = 0, {k: 0 for k in ks}

    for batch in tqdm(val_loader, desc="Evaluating full corpus (streaming)"):
        # Only the query side moves to the device; targets stay on CPU.
        pids   = batch['pid'].to(device, non_blocking=True)
        tokens = batch['tokens'].to(device, non_blocking=True)
        mask   = batch['mask'].to(device, non_blocking=True)
        pos    = batch['pos']

        user_vecs = model.playlist_forward(pids, tokens, mask)  # (B, D) on device
        B = user_vecs.size(0)

        # Running top-k per query: scores and row positions into the index, on device.
        best_scores = best_rows = None
        for start in range(0, n_items, score_chunk):
            end = min(start + score_chunk, n_items)
            chunk_embs = item_embs_cpu[start:end].to(device, non_blocking=True)  # (C, D)
            scores = user_vecs @ chunk_embs.T                                     # (B, C)
            vals, local_rows = scores.topk(k=min(max_k, scores.size(1)), dim=1)
            rows = local_rows + start  # chunk-local -> global index positions

            if best_scores is None:
                best_scores, best_rows = vals, rows
            else:
                cat_scores = torch.cat([best_scores, vals], dim=1)
                cat_rows = torch.cat([best_rows, rows], dim=1)
                # Clamp k: when the index is smaller than max_k, fewer candidates exist.
                best_scores, sel = cat_scores.topk(k=min(max_k, cat_scores.size(1)), dim=1)
                best_rows = torch.gather(cat_rows, 1, sel)
            del chunk_embs, scores, vals, local_rows, rows

        # Translate index rows to track IDs once, then compare with targets on CPU.
        topk_tids = tids_cpu[best_rows.cpu()]          # (B, <=max_k)
        pos_cpu = pos.detach().cpu().unsqueeze(1)      # (B, 1)
        for k in ks:
            hits_at_k[k] += (topk_tids[:, :k] == pos_cpu).any(dim=1).sum().item()
        total += B

        # Free per-batch tensors; releasing the MPS cache once per batch (not per
        # chunk) avoids a costly call inside the hot loop.
        del user_vecs, best_scores, best_rows, topk_tids
        if on_mps:
            torch.mps.empty_cache()

    denom = max(total, 1)  # empty loader -> 0.0 recall instead of ZeroDivisionError
    return {f"Recall@{k}": hits_at_k[k] / denom for k in ks}

In [None]:
# Load the best checkpoint of the "improved" run. The path and the tower-sharing flag
# are read from cfg_improved so they cannot drift from the config that produced it
# (replace with another path if evaluating a different run).
ckpt = torch.load(cfg_improved['checkpoint_path'], map_location=DEVICE)

# NOTE(review): mlp_hidden / cat_dim must match the values used when this checkpoint
# was trained, otherwise load_state_dict fails — TODO confirm against run_experiment.
best_model = TwoTower(
    n_tracks=n_tracks,
    n_artists=n_artists,
    n_albums=n_albums,
    aid_by_tid=aid_by_tid,
    alid_by_tid=alid_by_tid,
    tnum_by_tid=tnum_by_tid,
    pl_feat_by_pid=pl_feat_by_pid,
    embed_dim=embed_dim,
    mlp_hidden=128,
    share_embed=cfg_improved['share_embed'],
    cat_dim=32
).to(DEVICE)
best_model.load_state_dict(ckpt["model"])
best_model.eval()

# Embed every track once (finished in under a second in the recorded run).
item_index = build_item_index(best_model, n_tracks, device=DEVICE, batch_size=4096)

# Full-corpus Recall@K: each validation playlist is ranked against every indexed track.
# This is the slow step (~33 min in the recorded run).
recalls = evaluate_full_corpus(best_model, val_loader, item_index, ks=[10,50,100], device=DEVICE, score_chunk=8192)
print("Full-corpus recall:", {k: round(v, 4) for k, v in recalls.items()})


Building item index: 100%|██████████| 553/553 [00:00<00:00, 854.68it/s] 
Evaluating full corpus (streaming): 100%|██████████| 98/98 [33:06<00:00, 20.27s/it]

Full-corpus recall: {'Recall@10': 0.0161, 'Recall@50': 0.0534, 'Recall@100': 0.0829}





25/11/11 10:17:57 ERROR Inbox: Ignoring error
org.apache.spark.SparkException: Exception thrown in awaitResult: 
	at org.apache.spark.util.SparkThreadUtils$.awaitResult(SparkThreadUtils.scala:53)
	at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:342)
	at org.apache.spark.rpc.RpcTimeout.awaitResult(RpcTimeout.scala:75)
	at org.apache.spark.rpc.RpcEnv.setupEndpointRefByURI(RpcEnv.scala:102)
	at org.apache.spark.rpc.RpcEnv.setupEndpointRef(RpcEnv.scala:110)
	at org.apache.spark.util.RpcUtils$.makeDriverRef(RpcUtils.scala:36)
	at org.apache.spark.storage.BlockManagerMasterEndpoint.driverEndpoint$lzycompute(BlockManagerMasterEndpoint.scala:132)
	at org.apache.spark.storage.BlockManagerMasterEndpoint.org$apache$spark$storage$BlockManagerMasterEndpoint$$driverEndpoint(BlockManagerMasterEndpoint.scala:131)
	at org.apache.spark.storage.BlockManagerMasterEndpoint.isExecutorAlive$lzycompute$1(BlockManagerMasterEndpoint.scala:700)
	at org.apache.spark.storage.BlockManagerMasterE

In [None]:
# Sanity check on the trained `model`: learned temperature, then the mean L2 norm of
# item-embedding rows and of user-embedding rows (the item table is reused when the
# model has no separate `user_emb` attribute).
print("Final tau:", float(model.tau.detach().cpu()))
_user_table = model.user_emb if hasattr(model, 'user_emb') else model.item_emb
print(
    "Item/user emb norms:",
    model.item_emb.weight.norm(dim=1).mean().item(),
    _user_table.weight.norm(dim=1).mean().item(),
)


In [None]:
# --- Few-epoch smoke test ---
# Three quick epochs to confirm that training and in-batch validation run end to end.
# NOTE(review): this trains `model` in place with a fresh optimizer; any later cell
# that keeps training `model` starts from these weights, not from a fresh init.

# Optimizer: AdamW is a good default; SparseAdam is an option if you set sparse=True on embeddings.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-3)

EPOCHS = 3
for epoch in range(1, EPOCHS + 1):
    print(f"\n=== Epoch {epoch}/{EPOCHS} ===")
    tr = train_one_epoch(model, train_loader, optimizer, device=DEVICE, log_every=50)
    va = eval_inbatch(model, val_loader, device=DEVICE, ks=(10, 50))
    print(
        f"Train: loss={tr['train_loss']:.4f}  R@10={tr['train_R@10']:.4f}  R@50={tr['train_R@50']:.4f}  "
        f"throughput={tr['throughput_sps']:.1f} samples/s"
    )
    print(
        f"Valid: loss={va['loss']:.4f}  R@10={va['Recall@10']:.4f}  R@50={va['Recall@50']:.4f}"
    )

In [None]:
import torch, os

# Main training run: ReduceLROnPlateau on validation Recall@50, early stopping, and a
# checkpoint of the best epoch. `model` is trained in place and, at the end, restored
# to its best-epoch weights.
BEST_PATH = "checkpoints/twotower_best.pt"
# `or "."` keeps makedirs valid if BEST_PATH is ever a bare filename.
os.makedirs(os.path.dirname(BEST_PATH) or ".", exist_ok=True)

optimizer = torch.optim.AdamW(model.parameters(), lr=.001, weight_decay=.001)
# Halve the LR when val Recall@50 stalls (mode='max': higher is better).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=1, min_lr=1e-6
)

max_epochs = 20          # hard cap; early stopping will usually finish earlier
patience = 3             # stop if no val improvement for this many epochs
min_delta = 1e-4         # improvement must exceed this to count (filters tiny noise)
best_val = -1.0
no_improve = 0

for epoch in range(1, max_epochs+1):
    print(f"\n=== Epoch {epoch}/{max_epochs} (steps/epoch ≈ {len(train_loader)}) ===")
    # LR used during this epoch; scheduler.step below may lower it for the next one.
    epoch_lr = optimizer.param_groups[0]['lr']
    tr = train_one_epoch(model, train_loader, optimizer, device=DEVICE, log_every=100)
    va = eval_inbatch(model, val_loader, device=DEVICE, ks=(10, 50))

    val_r50 = va['Recall@50']  # primary metric
    scheduler.step(val_r50)

    print(f"Train: loss={tr['train_loss']:.4f} R@10={tr['train_R@10']:.4f} R@50={tr['train_R@50']:.4f} "
          f"throughput={tr['throughput_sps']:.1f}/s  lr={epoch_lr:.2e}")
    print(f"Valid: loss={va['loss']:.4f}  R@10={va['Recall@10']:.4f}  R@50={val_r50:.4f}")

    # checkpoint + early stop
    if val_r50 > best_val + min_delta:
        best_val = val_r50
        no_improve = 0
        torch.save({"model": model.state_dict(),
                    "val_R@50": best_val,
                    "epoch": epoch}, BEST_PATH)
        print(f"✓ New best R@50={best_val:.4f} — checkpointed to {BEST_PATH}")
    else:
        no_improve += 1
        print(f"(no improvement for {no_improve} epoch(s))")
        if no_improve >= patience:
            print(f"Early stopping: no val R@50 improvement for {patience} epochs.")
            break

# Restore the best epoch's weights. Without this, after early stopping `model` holds the
# last epoch's weights, which by construction did not improve on the best one.
# best_val >= 0 means at least one checkpoint was written above.
if best_val >= 0:
    best_ckpt = torch.load(BEST_PATH, map_location=DEVICE)
    model.load_state_dict(best_ckpt["model"])
    print(f"Restored best weights from epoch {best_ckpt['epoch']} "
          f"(val R@50={best_ckpt['val_R@50']:.4f}) from {BEST_PATH}")
