In [1]:
# Import Libraries
import os, sys, findspark, numpy as np
from pyspark.sql import SparkSession
from pyspark.sql import functions as F, Window as W

In [2]:
# Spark needs a Java 17 JVM. The default below is a Homebrew path on the author's
# Mac; export SPOTIFYREC_JAVA_HOME to point at a different JDK without editing code.
os.environ["JAVA_HOME"] = os.environ.get(
    "SPOTIFYREC_JAVA_HOME",
    "/opt/homebrew/Cellar/openjdk@17/17.0.17/libexec/openjdk.jdk/Contents/Home",
)
# Run Spark's Python driver and workers with the same interpreter as this kernel (Python 3.11).
os.environ["PYSPARK_PYTHON"] = sys.executable
os.environ["PYSPARK_DRIVER_PYTHON"] = sys.executable

# Local Spark session: 4 cores, 10 GB driver memory, adaptive query execution enabled.
findspark.init()
spark = (
    SparkSession.builder
    .appName("SpotifyRec")
    .master("local[4]")
    .config("spark.driver.memory", "10g")
    .config("spark.sql.adaptive.enabled", "true")
    .getOrCreate()
)
# Only surface ERROR-level log messages (hides Spark's WARN/INFO noise).
spark.sparkContext.setLogLevel("ERROR")

Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
25/11/10 00:36:43 WARN Utils: Your hostname, Ethans-MacBook-Pro.local, resolves to a loopback address: 127.0.0.1; using 100.64.14.129 instead (on interface en0)
25/11/10 00:36:43 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/11/10 00:36:44 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
def _load_view(path, view_name):
    """Read a parquet directory with Spark and register it as a temp SQL view.

    Returns the DataFrame so it can also be used directly from Python.
    """
    frame = spark.read.parquet(path)
    frame.createOrReplaceTempView(view_name)
    return frame


# Interactions (playlist <-> track rows)
playlist_tracks = _load_view("parquet_data/playlist_tracks", "playlist_tracks")
# Users (playlists)
playlists = _load_view("parquet_data/playlists", "playlists")
# Items (tracks)
tracks = _load_view("parquet_data/tracks", "tracks")

# Model-training tables, also exposed to Spark SQL
track_features = _load_view("parquet_data/track_features", "track_features")          # item features
playlist_features = _load_view("parquet_data/playlist_features", "playlist_features")  # user features
edges = _load_view("parquet_data/edges", "edges")                                      # item-user interactions

# Train / validation pairs are only needed as DataFrames (no SQL view).
train_pairs = spark.read.parquet("parquet_data/train_pairs")
val_pairs = spark.read.parquet("parquet_data/val_pairs")

In [4]:
# DataFrame.show() prints the table itself and returns None, so calling it inside
# print() only adds a stray "None" line after each table.
train_pairs.show(5)
val_pairs.show(5)
edges.show(5)
track_features.show(5)
playlist_features.show(5)

+---+--------------------+--------------------+-------+
|pid|              tokens|                mask|pos_tid|
+---+--------------------+--------------------+-------+
|  2|[2055362, 902534,...|[1, 1, 1, 1, 1, 1...| 285459|
|  8|[953644, 1226541,...|[1, 1, 1, 1, 1, 1...| 909553|
| 11|[1279865, 1780757...|[1, 1, 1, 1, 1, 1...| 721318|
| 35|[2102277, 1427277...|[1, 1, 1, 1, 1, 1...| 372978|
| 49|[73755, 1607709, ...|[1, 1, 1, 1, 1, 1...| 152751|
+---+--------------------+--------------------+-------+
only showing top 5 rows
None
+---+--------------------+--------------------+-------+
|pid|              tokens|                mask|pos_tid|
+---+--------------------+--------------------+-------+
|192|[2137547, 747563,...|[1, 1, 1, 1, 1, 1...| 146944|
|249|[1012256, 201022,...|[1, 1, 1, 1, 1, 1...|1715753|
|337|[1385456, 338913,...|[1, 1, 1, 1, 1, 1...| 712726|
|352|[598275, 1742374,...|[1, 1, 1, 1, 1, 1...|1311226|
|381|[2126317, 877997,...|[1, 1, 1, 1, 1, 1...|1856934|
+---+--------------

In [5]:
import pyarrow.dataset as ds
import torch
import numpy as np

# ---- Track-side lookups: dense arrays indexed directly by track id (tid) ----
PARQUET_T_FEATURES = 'parquet_data/track_features'
ttab = ds.dataset(PARQUET_T_FEATURES, format='parquet').to_table()
tid, aid, alid = (np.array(ttab[col], dtype=np.int64) for col in ('tid', 'aid', 'alid'))
tnum = np.stack(
    [np.array(ttab[col], dtype=np.float32) for col in ('z_log_track_cnt', 'z_log_artist_cnt')],
    axis=1,
)
# Table sizes are max id + 1 so every id can be used as a row index.
n_tracks, n_artists, n_albums = (int(ids.max()) + 1 for ids in (tid, aid, alid))

# Rows for tids missing from track_features stay zero (artist/album id 0).
# NOTE(review): TwoTower uses padding_idx=0 for artist/album embeddings, so such
# tracks get the padding embedding -- confirm id 0 is reserved in the source data.
aid_by_tid = np.zeros(n_tracks, dtype=np.int64)
alid_by_tid = np.zeros(n_tracks, dtype=np.int64)
tnum_by_tid = np.zeros((n_tracks, tnum.shape[1]), dtype=np.float32)
aid_by_tid[tid], alid_by_tid[tid], tnum_by_tid[tid] = aid, alid, tnum
aid_by_tid, alid_by_tid, tnum_by_tid = map(torch.from_numpy, (aid_by_tid, alid_by_tid, tnum_by_tid))

# ---- Playlist-side lookup: numeric feature rows indexed by playlist id (pid) ----
PARQUET_P_FEATURES = 'parquet_data/playlist_features'
ptab = ds.dataset(PARQUET_P_FEATURES, format='parquet').to_table()
pid = np.array(ptab['pid'])
pf = np.stack(
    [
        np.array(ptab[col], dtype=np.float32)
        for col in (
            'log_n_tracks', 'log_n_artists', 'log_n_albums',
            'log_pl_duration', 'log_days_mod', 'collaborative',
        )
    ],
    axis=1,
)

n_pids = int(pid.max()) + 1
pl_feat_by_pid = np.zeros((n_pids, pf.shape[1]), dtype=np.float32)
pl_feat_by_pid[pid] = pf
pl_feat_by_pid = torch.from_numpy(pl_feat_by_pid)

In [6]:
import pyarrow.dataset as ds
import torch
import warnings

PARQUET_TRAIN = 'parquet_data/train_pairs'
PARQUET_VAL = 'parquet_data/val_pairs'

# Rows in the shared track-id embedding table.
# NOTE: this overrides the n_tracks derived from track_features in the previous
# cell. The per-track lookup buffers built there (aid_by_tid, ...) are indexed with
# the same ids as this embedding, so a mismatch can make TwoTower.track_forward
# index out of range -- surface it instead of failing deep inside training.
n_tracks = 2262108
if aid_by_tid.shape[0] != n_tracks:
    warnings.warn(
        f"n_tracks={n_tracks} but track_features lookups cover "
        f"{aid_by_tid.shape[0]} ids; TwoTower.track_forward may index out of range."
    )
max_length = 20
embed_dim = 128
batch_size = 512  # raise if you have more GPU RAM
# Prefer CUDA, then Apple-silicon MPS, else CPU.
DEVICE = (
    'cuda' if torch.cuda.is_available()
    else ('mps' if torch.backends.mps.is_available() else 'cpu')
)

class PlaylistPairDataset(torch.utils.data.Dataset):
    """Map-style dataset of (playlist id, context tokens, mask, positive track) rows.

    The parquet data at `parquet_path` is fully materialised into NumPy arrays in
    the constructor; `__getitem__` only slices those arrays.
    """

    def __init__(self, parquet_path):
        wanted = ['pid', 'tokens', 'mask', 'pos_tid']
        tbl = ds.dataset(parquet_path, format='parquet').to_table(columns=wanted)

        def dense(col, dtype):
            # list column -> Python lists -> 2-D array (rows must share one length)
            return np.stack(tbl[col].to_pylist()).astype(dtype)

        self.pid = np.asarray(tbl['pid']).astype(np.int64)
        self.tokens = dense('tokens', np.int64)
        self.mask = dense('mask', np.float32)
        self.pos = np.asarray(tbl['pos_tid']).astype(np.int64)

    def __len__(self):
        n_rows = self.pos.shape[0]
        return n_rows

    def __getitem__(self, i):
        # tokens/mask are zero-copy views of the dataset arrays; ids become 0-d long tensors.
        sample = dict(
            pid=torch.tensor(self.pid[i], dtype=torch.long),
            tokens=torch.from_numpy(self.tokens[i]),
            mask=torch.from_numpy(self.mask[i]),
            pos=torch.tensor(self.pos[i], dtype=torch.long),
        )
        return sample

def make_loader(path, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True):
    """Build a PlaylistPairDataset for `path` and a DataLoader over it.

    Args:
        path: parquet directory with pid / tokens / mask / pos_tid columns.
        batch_size: rows per batch (defaults to the notebook-level `batch_size`).
        shuffle: reshuffle rows every epoch.
        num_workers: DataLoader worker processes.
        drop_last: drop the final partial batch. Defaults to True so every
            training batch provides the same number of in-batch negatives; pass
            False for evaluation loaders so no examples are silently discarded.

    Returns:
        (dataset, loader) tuple.
    """
    dataset = PlaylistPairDataset(path)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=(DEVICE == 'cuda'),  # pinned host memory only speeds up CUDA copies
        drop_last=drop_last,
    )
    return dataset, loader

In [7]:
import torch.nn as nn
import torch.nn.functional as F

class TwoTower(nn.Module):
    """Two-tower retrieval model scoring playlists against tracks by dot product.

    Playlist tower: masked mean of track-id embeddings over the playlist's context
    tokens, fused with an MLP over per-playlist numeric features.
    Track tower: track-id embedding fused with artist/album embeddings and an MLP
    over per-track numeric features.
    Both towers return L2-normalised `embed_dim` vectors; `forward` returns their
    similarity matrix divided by a learnable temperature `tau`.

    Args:
        n_tracks / n_artists / n_albums: embedding table sizes (max id + 1).
        aid_by_tid, alid_by_tid: per-track artist / album id lookups, indexed by track id.
        tnum_by_tid: per-track numeric features, shape (n_ids, n_track_feats).
        pl_feat_by_pid: per-playlist numeric features, shape (n_pids, n_playlist_feats).
        embed_dim: output dimension of both towers.
        mlp_hidden: hidden width of the numeric-feature MLPs.
        share_embed: reuse the track-id embedding table for playlist context tokens.
        cat_dim: artist / album embedding dimension.
    """

    def __init__(self,
                 n_tracks,
                 n_artists,
                 n_albums,
                 aid_by_tid,
                 alid_by_tid,
                 tnum_by_tid,
                 pl_feat_by_pid,
                 embed_dim=128,
                 mlp_hidden=256,
                 share_embed=True,
                 cat_dim=64):
        super().__init__()
        # Register lookups as buffers so .to(DEVICE) moves them with the weights.
        self.register_buffer('aid_by_tid', aid_by_tid.long())
        self.register_buffer('alid_by_tid', alid_by_tid.long())
        self.register_buffer('tnum_by_tid', tnum_by_tid.float())
        self.register_buffer('pl_feat_by_pid', pl_feat_by_pid.float())

        # Id embeddings; index 0 is the padding id (zero vector, no gradient).
        # NOTE(review): this also applies to the track tower -- confirm tid 0 is not a real track.
        self.item_emb = nn.Embedding(n_tracks, embed_dim, padding_idx=0)
        self.user_emb = self.item_emb if share_embed else nn.Embedding(n_tracks, embed_dim, padding_idx=0)

        # Categorical embeddings
        self.artist_emb = nn.Embedding(n_artists, cat_dim, padding_idx=0)
        self.album_emb = nn.Embedding(n_albums, cat_dim, padding_idx=0)

        # Numeric-feature MLPs; input widths follow the lookup tables instead of
        # being hard-coded, so adding a feature column needs no model change.
        self.item_num_mlp = nn.Sequential(
            nn.Linear(self.tnum_by_tid.size(1), mlp_hidden // 2),
            nn.ReLU(),
            nn.Linear(mlp_hidden // 2, embed_dim),
        )
        self.user_num_mlp = nn.Sequential(
            nn.Linear(self.pl_feat_by_pid.size(1), mlp_hidden),
            nn.ReLU(),
            nn.Linear(mlp_hidden, embed_dim),
        )
        # Fuse + project back to embed_dim
        self.item_fuse = nn.Linear(embed_dim + cat_dim + cat_dim + embed_dim, embed_dim)
        self.user_fuse = nn.Linear(embed_dim + embed_dim, embed_dim)

        # Learnable softmax temperature (clamped to >= 0.001 in forward).
        self.tau = nn.Parameter(torch.tensor(1.0))

    # ---- towers ----
    def playlist_forward(self, pids, tokens, mask):
        """Encode playlists into unit vectors.

        Assumes tokens is (B, L) track ids and mask is (B, L) with 1 for real
        tokens and 0 for padding -- TODO confirm against the data pipeline.
        """
        # Masked mean of context-token embeddings; clamp avoids /0 for empty rows.
        e = self.user_emb(tokens) * mask.unsqueeze(-1)
        denom = mask.sum(dim=1, keepdim=True).clamp_min(1.0)
        seq_vec = F.normalize(e.sum(dim=1) / denom, dim=-1)

        # Playlist-level numeric features
        pf_vec = F.normalize(self.user_num_mlp(self.pl_feat_by_pid[pids]), dim=-1)

        fused = torch.cat([seq_vec, pf_vec], dim=-1)
        return F.normalize(self.user_fuse(fused), dim=-1)

    def track_forward(self, pos_ids):
        """Encode a 1-D tensor of track ids into unit vectors of size embed_dim."""
        base = self.item_emb(pos_ids)
        a = self.artist_emb(self.aid_by_tid[pos_ids])
        al = self.album_emb(self.alid_by_tid[pos_ids])
        xn = self.item_num_mlp(self.tnum_by_tid[pos_ids])
        fused = torch.cat([base, a, al, xn], dim=-1)
        return F.normalize(self.item_fuse(fused), dim=-1)

    def forward(self, pids, tokens, mask, pos):
        """Return the (len(pids), len(pos)) playlist x track similarity matrix / tau.

        Row i's positive is column i, so the matrix can be fed straight to an
        in-batch cross-entropy with targets arange(B).
        """
        p = self.playlist_forward(pids, tokens, mask)
        t = self.track_forward(pos)
        logits = (p @ t.T) / self.tau.clamp_min(.001)
        return logits

# Positional args follow TwoTower's signature: sizes, then the four lookup tables.
model = TwoTower(
    n_tracks, n_artists, n_albums,
    aid_by_tid, alid_by_tid, tnum_by_tid, pl_feat_by_pid,
    embed_dim=embed_dim, mlp_hidden=256, share_embed=True, cat_dim=64,
).to(DEVICE)
opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
# Gradient scaling (mixed precision) is only used on CUDA.
if DEVICE == 'cuda':
    scaler = torch.amp.GradScaler()
else:
    scaler = None
    
def _in_batch_loss(pids, tokens, mask, pos):
    """In-batch softmax loss: row i's positive is column i, other columns are negatives."""
    logits = model(pids, tokens, mask, pos)
    targets = torch.arange(logits.size(0), device=logits.device)
    return F.cross_entropy(logits, targets)


def train_one_epoch():
    """Run one pass over `train_loader`, updating `model` in place.

    Uses the notebook globals model, opt, scaler, train_loader and DEVICE. When
    `scaler` is set (CUDA only) the forward pass runs under autocast, and the
    gradients are unscaled before clipping so the 1.0 max-norm applies to the
    real gradients rather than the loss-scaled ones.

    Returns:
        Mean training loss over the epoch's batches.
    """
    model.train()
    running = 0.0
    for batch in train_loader:
        pids = batch['pid'].to(DEVICE)
        tokens = batch['tokens'].to(DEVICE)
        mask = batch['mask'].to(DEVICE)
        pos = batch['pos'].to(DEVICE)
        opt.zero_grad(set_to_none=True)

        if scaler is not None:
            # torch.amp.autocast requires an explicit device_type.
            with torch.amp.autocast(device_type=DEVICE):
                loss = _in_batch_loss(pids, tokens, mask, pos)
            scaler.scale(loss).backward()
            # Unscale in place first; otherwise clipping compares scaled grads to 1.0.
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(opt)
            scaler.update()
        else:
            loss = _in_batch_loss(pids, tokens, mask, pos)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
        running += loss.item()
    return running / max(1, len(train_loader))

def recall_at_k(loader, K=10, pool=2000, net=None, device=None):
    """Estimate Recall@K against a shared pool of sampled negatives.

    The candidate pool is the first `pool` positive track ids yielded by `loader`.
    Every example is ranked as ``[its own positive] + pool``. Pool entries equal to
    that positive are masked out so they cannot tie with it. The example is a hit
    when the positive lands in the top K scores.

    Args:
        loader: iterable of batches with 'pid', 'tokens', 'mask', 'pos' tensors.
            It is iterated twice (pool construction, then evaluation).
        K: rank cutoff; clamped to the number of candidates per example.
        pool: number of shared candidate track ids.
        net: model exposing playlist_forward / track_forward (defaults to global `model`).
        device: device to evaluate on (defaults to global `DEVICE`).

    Returns:
        Fraction of examples whose positive ranks within the top K (0.0 for an empty loader).
    """
    net = model if net is None else net
    device = DEVICE if device is None else device
    net.eval()

    cand = []
    for b in loader:
        cand.extend(b['pos'].tolist())
        if len(cand) >= pool:
            break
    # Explicit long dtype so an empty pool is still a valid index tensor.
    cand = torch.tensor(cand[:pool], dtype=torch.long, device=device)

    hits = n = 0
    with torch.no_grad():
        # The pool is shared by every example, so encode it once: (P, D).
        t_pool = net.track_forward(cand)
        for b in loader:
            pids = b['pid'].to(device)
            tokens = b['tokens'].to(device)
            mask = b['mask'].to(device)
            pos = b['pos'].to(device)

            p = net.playlist_forward(pids, tokens, mask)                      # (B, D)
            pos_scores = (p * net.track_forward(pos)).sum(-1, keepdim=True)   # (B, 1)
            pool_scores = (p @ t_pool.T).masked_fill(
                cand.unsqueeze(0) == pos.unsqueeze(1), float('-inf')
            )                                                                 # (B, P)
            # Column 0 is always the example's own positive.
            scores = torch.cat([pos_scores, pool_scores], dim=1)
            k = min(K, scores.size(1))
            topk = scores.topk(k, dim=1).indices
            hits += (topk == 0).any(dim=1).sum().item()
            n += pos.size(0)
    return hits / max(1, n)

In [8]:
import time, math

def bench_loader(loader, steps=200, warmup=10, net=None, device=None):
    """Print end-to-end throughput of `loader` plus one no-grad model forward per batch.

    The loader is cycled, so it may have fewer batches than `warmup + steps`.
    Warmup batches also run the model, so one-off costs (kernel selection, caches)
    are excluded. The device is synchronized around the timed region because
    CUDA/MPS kernels run asynchronously.

    Args:
        loader: iterable of batches with 'pid', 'tokens', 'mask', 'pos' tensors.
        steps: number of timed batches.
        warmup: number of untimed batches run first.
        net: callable model (defaults to global `model`).
        device: target device (defaults to global `DEVICE`).

    Raises:
        ValueError: if `loader` yields no batches.
    """
    net = model if net is None else net
    device = DEVICE if device is None else device

    def cycle():
        # Restart the loader whenever it is exhausted instead of raising StopIteration.
        while True:
            produced = False
            for b in loader:
                produced = True
                yield b
            if not produced:
                raise ValueError("bench_loader: loader yielded no batches")

    def sync():
        if str(device).startswith('cuda'):
            torch.cuda.synchronize()
        elif str(device).startswith('mps'):
            torch.mps.synchronize()

    def step(b):
        net(
            b['pid'].to(device),
            b['tokens'].to(device),
            b['mask'].to(device),
            b['pos'].to(device),
        )
        return len(b['pos'])

    stream = cycle()
    n = samples = 0
    with torch.no_grad():
        for _ in range(warmup):
            step(next(stream))
        sync()
        t0 = time.perf_counter()
        for _ in range(steps):
            samples += step(next(stream))
            n += 1
        sync()
    dt = max(time.perf_counter() - t0, 1e-9)
    print(f"{n/dt:.2f} batches/sec  | {samples/dt:.0f} samples/sec")

In [9]:
# Training split is reshuffled every epoch so in-batch negatives vary between epochs.
train_ds, train_loader = make_loader(
    PARQUET_TRAIN,
    shuffle=True,
    num_workers=0,
)
# Validation keeps a fixed order for reproducible evaluation.
val_ds, val_loader = make_loader(
    PARQUET_VAL,
    shuffle=False,
    num_workers=0,
)

# Throughput check: data loading + one forward pass per batch.
bench_loader(train_loader)

104.70 batches/sec  | 53607 samples/sec


In [10]:
# Smoke test: a forward pass on 8 examples must yield an 8 x 8 in-batch logits matrix.
n_check = 8
b = next(iter(train_loader))
with torch.no_grad():
    logits = model(*(b[key][:n_check].to(DEVICE) for key in ('pid', 'tokens', 'mask', 'pos')))
print('logits', logits.shape)
assert logits.shape == (n_check, n_check), f"unexpected logits shape {tuple(logits.shape)}"

logits torch.Size([8, 8])
