In [None]:
# Import Libraries
import os, sys, findspark, numpy as np
from pyspark.sql import SparkSession
from pyspark.sql import functions as F, Window as W
# Collision-free alias for Spark SQL functions: the model cell later rebinds `F`
# to torch.nn.functional, so Spark cells that run after it must use `psf`.
from pyspark.sql import functions as psf

In [None]:
# Uses Java 17 & Python 3.11.
# Respect a JAVA_HOME that is already configured; the Homebrew path below is only a
# fallback for the original author's machine (it is pinned to one JDK patch release).
os.environ.setdefault(
    "JAVA_HOME",
    "/opt/homebrew/Cellar/openjdk@17/17.0.17/libexec/openjdk.jdk/Contents/Home",
)
# Make Spark workers and the driver use this kernel's Python interpreter.
os.environ["PYSPARK_PYTHON"] = sys.executable
os.environ["PYSPARK_DRIVER_PYTHON"] = sys.executable
# NOTE(review): pyspark is already imported in the first cell, so findspark.init()
# here can only affect imports that happen after this point.
findspark.init()
# Build a local PySpark session on 4 cores with a 10 GB driver heap and adaptive
# query execution enabled.
spark = (
    SparkSession.builder
    .appName("SpotifyRec")
    .master("local[4]")
    .config("spark.driver.memory", "10g")
    .config("spark.sql.adaptive.enabled", "true")
    .getOrCreate()
)
# Only show ERROR-level logs (hides Spark's WARN/INFO noise; errors still print).
spark.sparkContext.setLogLevel("ERROR")

In [None]:
def _load_view(path, view_name):
    """Read a parquet dataset with Spark and register it as a temporary SQL view."""
    frame = spark.read.parquet(path)
    frame.createOrReplaceTempView(view_name)
    return frame

# Interactions (playlist <-> track rows)
playlist_tracks = _load_view("parquet_data/playlist_tracks", "playlist_tracks")
# Users (playlists)
playlists = _load_view("parquet_data/playlists", "playlists")
# Items (tracks)
tracks = _load_view("parquet_data/tracks", "tracks")
# Item features, track level (used for model training)
track_features = _load_view('parquet_data/track_features', "track_features")
# Item features, artist level (used for model training)
artist_features = _load_view('parquet_data/artist_features', 'artist_features')
# User features (used for model training)
playlist_features = _load_view('parquet_data/playlist_features', "playlist_features")
# Item-user interactions (used for model training)
edges = _load_view('parquet_data/edges', "edges")

In [None]:
# Show the column names and types of the edges (item-user interactions) table
edges.printSchema()

In [None]:
# Shift every track id up by one; presumably this frees id 0 for PAD_ID, which the
# sequence-building cell uses as its padding token.
def _shift_tid(frame):
    """Return `frame` with its 'tid' column incremented by 1."""
    return frame.withColumn("tid", F.col("tid") + 1)

edges_shift = _shift_tid(edges)
track_features_shift = _shift_tid(track_features)
edges_shift.createOrReplaceTempView('edges_shift')
track_features_shift.createOrReplaceTempView('track_features_shift')

In [None]:
# One row per playlist holding its distinct track ids. Playlists with fewer than two
# tracks are dropped so at least one context track remains after a positive is held out.
pl_seqs = (
    edges_shift
    .groupBy('pid')
    .agg(F.array_distinct(F.collect_list('tid')).alias('tids'))
    .where(F.size(F.col('tids')) >= F.lit(2))
)

In [None]:
# Build fixed-length model inputs per playlist: hold out one random track as the
# positive, keep at most `max_length` of the remaining tracks as context, pad the
# context with PAD_ID, and build a 1/0 mask marking real vs. padded slots.
PAD_ID = 0
max_length = 20

pairs = (
    pl_seqs
    # Random positive = first element of a shuffled copy of the playlist's tracks
    .withColumn('shuffled', F.shuffle(F.col('tids')))
    .withColumn('pos_tid', F.element_at(F.col('shuffled'), 1))
    # Context pool = every track in the playlist except the positive
    .withColumn('context_pool', F.filter('tids', lambda tid: tid != F.col('pos_tid')))
    # Random subset of at most max_length context tracks
    .withColumn('context', F.slice(F.shuffle(F.col('context_pool')), 1, max_length))
    # Number of real slots and of padding slots needed to reach max_length
    .withColumn('n_real', F.size('context'))
    .withColumn('n_pad', F.greatest(F.lit(0), F.lit(max_length) - F.col('n_real')))
    .withColumn('tokens', F.concat(F.col('context'), F.array_repeat(F.lit(PAD_ID), F.col('n_pad'))))
    .withColumn('mask', F.concat(F.array_repeat(F.lit(1), F.col('n_real')), F.array_repeat(F.lit(0), F.col('n_pad'))))
    # Keep only the model inputs and the target
    .select('pid', 'tokens', 'mask', 'pos_tid')
)

In [None]:
# Deterministic ~95/5 train/validation split keyed on playlist id. Each pid is
# hashed into one of 100 buckets, so a playlist never lands in both splits and the
# assignment is the same on every run.
bucketed = pairs.withColumn('bucket', F.pmod(F.abs(F.hash('pid')), F.lit(100)))
train_pairs = bucketed.filter('bucket < 95').select('pid', 'tokens', 'mask', 'pos_tid')  # ~95% of playlists
val_pairs = bucketed.filter('bucket >= 95').select('pid', 'tokens', 'mask', 'pos_tid')   # ~5% of playlists

# Persist both splits: the PyTorch cells read them back from these paths, and nothing
# else in the notebook writes them, so a fresh Restart & Run All would otherwise fail.
train_pairs.write.mode('overwrite').parquet('parquet_data/train_pairs')
val_pairs.write.mode('overwrite').parquet('parquet_data/val_pairs')

In [None]:
# Validate train/val split
print("train rows:", train_pairs.count())
print("val rows:",   val_pairs.count())
# Ids index an embedding table directly, so it needs max(tid) + 1 rows. `n_tracks`
# holds that table size, matching how the model cells use the name.
max_tid = edges_shift.agg(F.max('tid').alias('max_tid')).first()['max_tid']
n_tracks = int(max_tid) + 1
print('n_tracks (embedding size):', n_tracks)

In [None]:
import pyarrow.dataset as ds
import torch

PARQUET_TRAIN = 'parquet_data/train_pairs'
PARQUET_VAL = 'parquet_data/val_pairs'

# NOTE(review): hard-coded embedding-table size. It must be >= max(shifted tid) + 1
# (the "n_tracks (embedding size)" printed by the split-validation cell); any larger
# id makes nn.Embedding raise an index error.
n_tracks = 2262108
max_length = 20
embed_dim = 128
batch_size = 512  # raise if have GPU RAM
SEED = 42

# Seed torch's default generators so weight init and DataLoader shuffling are
# reproducible across runs.
torch.manual_seed(SEED)

# Pick the best available accelerator: CUDA, then Apple MPS, otherwise CPU.
DEVICE = (
    'cuda' if torch.cuda.is_available()
    else ('mps' if torch.backends.mps.is_available() else 'cpu')
)

class PlaylistPairDataset(torch.utils.data.Dataset):
    """In-memory dataset of (context tokens, context mask, positive track) rows.

    The ``tokens``, ``mask`` and ``pos_tid`` columns of a parquet dataset are read
    once and kept as NumPy arrays; each item is returned as a dict of torch tensors.
    """

    def __init__(self, parquet_path):
        arrow_table = ds.dataset(parquet_path, format='parquet').to_table(
            columns=['tokens', 'mask', 'pos_tid'])

        def dense(column, dtype):
            # list column -> 2-D array (rows x list length); np.stack needs every
            # row to have the same length.
            return np.stack(arrow_table[column].to_pylist()).astype(dtype)

        self.tokens = dense('tokens', np.int64)
        self.mask = dense('mask', np.bool_)
        self.pos = np.asarray(arrow_table['pos_tid']).astype(np.int64)

    def __len__(self):
        return len(self.pos)

    def __getitem__(self, i):
        token_row = torch.from_numpy(self.tokens[i])
        mask_row = torch.from_numpy(self.mask[i])
        positive = torch.tensor(self.pos[i], dtype=torch.long)
        return {'tokens': token_row, 'mask': mask_row, 'pos': positive}

def make_loader(path, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=None):
    """Load a pairs parquet into a PlaylistPairDataset and wrap it in a DataLoader.

    Args:
        path: parquet dataset path with 'tokens', 'mask' and 'pos_tid' columns.
        batch_size: rows per batch.
        shuffle: shuffle rows every epoch.
        num_workers: DataLoader worker processes.
        drop_last: drop the final incomplete batch. Defaults to ``shuffle``, so the
            training loader keeps fixed-size batches (its loss uses in-batch
            negatives) while an evaluation loader (shuffle=False) keeps every row.

    Returns:
        (dataset, loader) tuple.
    """
    if drop_last is None:
        drop_last = shuffle
    dataset = PlaylistPairDataset(path)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=(DEVICE == 'cuda'),  # pinned host memory speeds host->GPU copies
        drop_last=drop_last,
    )
    return dataset, loader

train_ds, train_loader = make_loader(PARQUET_TRAIN, shuffle=True, num_workers=0)
val_ds, val_loader = make_loader(PARQUET_VAL, shuffle=False, num_workers=0)

# n_tracks is hard-coded, so check that every id we will embed fits in the table.
# An off-by-one then surfaces here instead of as an index error inside nn.Embedding.
max_seen_id = max(
    int(max(split.tokens.max(initial=0), split.pos.max(initial=0)))
    for split in (train_ds, val_ds)
)
assert max_seen_id < n_tracks, (
    f'track id {max_seen_id} does not fit an embedding table of size {n_tracks}'
)

In [None]:
# Validate tensors: print the shapes of one training batch (tokens, mask, pos)
b = next(iter(train_loader))
print(*(b[key].shape for key in ('tokens', 'mask', 'pos')))

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class TwoTower(nn.Module):
    """Two-tower retrieval model: a playlist (user) tower and a track (item) tower.

    Both towers embed track ids (id 0 is the padding index) and pass them through
    a 2-layer MLP. Outputs are L2-normalized, so scores are cosine similarities
    divided by a learnable temperature ``tau`` (clamped to at least 0.001). With
    ``share_embed=True`` both towers use the same embedding table.
    """

    def __init__(self, n_tracks, embed_dim=128, mlp_hidden=256, share_embed=True):
        super().__init__()

        def projection():
            return nn.Sequential(
                nn.Linear(embed_dim, mlp_hidden), nn.ReLU(), nn.Linear(mlp_hidden, embed_dim)
            )

        self.item_emb = nn.Embedding(n_tracks, embed_dim, padding_idx=0)
        if share_embed:
            self.user_emb = self.item_emb
        else:
            self.user_emb = nn.Embedding(n_tracks, embed_dim, padding_idx=0)
        self.item_mlp = projection()
        self.user_mlp = projection()
        self.tau = nn.Parameter(torch.tensor(1.0))

    def playlist_forward(self, tokens, mask):
        """Mean-pool embeddings of the unmasked tokens, then MLP and L2-normalize."""
        weights = mask.unsqueeze(-1)
        summed = (self.user_emb(tokens) * weights).sum(dim=1)
        # Guard against all-padding rows (count 0)
        count = mask.sum(dim=1, keepdim=True).clamp_min(1.0)
        return F.normalize(self.user_mlp(summed / count), dim=-1)

    def track_forward(self, pos_ids):
        """Embed track ids, then MLP and L2-normalize."""
        return F.normalize(self.item_mlp(self.item_emb(pos_ids)), dim=-1)

    def forward(self, tokens, mask, pos):
        """Score every playlist in the batch against every track in ``pos``."""
        playlists = self.playlist_forward(tokens, mask)
        tracks = self.track_forward(pos)
        temperature = self.tau.clamp_min(.001)
        return playlists @ tracks.T / temperature
    
# Instantiate the model on the selected device with an AdamW optimizer.
model = TwoTower(n_tracks, embed_dim, mlp_hidden=256, share_embed=True).to(DEVICE)
opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
# Loss scaler for CUDA mixed precision; stays None on every other device.
scaler = None
if DEVICE == 'cuda':
    scaler = torch.amp.GradScaler()
    
def _in_batch_loss(tokens, mask, pos):
    """Softmax loss where row i's target is column i (the other columns act as negatives)."""
    logits = model(tokens, mask, pos)
    targets = torch.arange(logits.size(0), device=logits.device)
    return F.cross_entropy(logits, targets)


def train_one_epoch():
    """Run one pass over ``train_loader`` and return the mean batch loss.

    Uses the notebook globals ``model``, ``opt``, ``scaler``, ``train_loader`` and
    ``DEVICE``. When ``scaler`` is set (CUDA only), the forward pass runs under
    autocast with loss scaling. Gradients are clipped to a global norm of 1.0.

    Returns:
        Mean loss over the epoch's batches (0.0 if the loader is empty).
    """
    model.train()
    running = 0.0
    for batch in train_loader:
        tokens = batch['tokens'].to(DEVICE)
        mask = batch['mask'].to(DEVICE)
        pos = batch['pos'].to(DEVICE)
        opt.zero_grad(set_to_none=True)
        if scaler is not None:
            # torch.amp.autocast requires device_type; the scaler only exists on CUDA.
            with torch.amp.autocast(device_type='cuda'):
                loss = _in_batch_loss(tokens, mask, pos)
            scaler.scale(loss).backward()
            # Unscale first so clipping sees the true gradient norm, not the
            # loss-scaled one (otherwise the effective clip threshold is wrong).
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(opt)
            scaler.update()
        else:
            loss = _in_batch_loss(tokens, mask, pos)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
        running += loss.item()
    return running / max(1, len(train_loader))

def recall_at_k(loader, K=10, pool=2000, net=None):
    """Sampled Recall@K of a two-tower model on a pairs loader.

    A shared candidate pool is built from the first ``pool`` positive track ids in
    ``loader``. Each query playlist's own positive is then scored against that
    pool. The query counts as a hit when fewer than ``K`` candidates score strictly
    higher than its positive, so ties favour the positive. A pool entry equal to
    the query's positive is ignored so the positive is never its own negative.

    Args:
        loader: re-iterable of batches with 'tokens', 'mask' and 'pos' tensors.
            It is iterated twice: once for the pool, once for scoring.
        K: rank cut-off.
        pool: maximum number of candidate tracks.
        net: model exposing ``playlist_forward`` / ``track_forward``. Defaults to
            the notebook-global ``model``.

    Returns:
        Hit rate in [0, 1], or 0.0 when the loader yields no examples.
    """
    net = model if net is None else net
    device = next(net.parameters()).device
    net.eval()

    cand = []
    for b in loader:
        cand.extend(b['pos'].tolist())
        if len(cand) >= pool:
            break
    cand = torch.tensor(cand[:pool], dtype=torch.long, device=device)

    hits = n = 0
    with torch.no_grad():  # evaluation only: don't build autograd graphs
        # Candidate embeddings are identical for every query, so compute them once.
        cand_emb = net.track_forward(cand)
        for b in loader:
            tokens = b['tokens'].to(device)
            mask = b['mask'].to(device)
            pos = b['pos'].to(device)
            p = net.playlist_forward(tokens, mask)
            pos_score = (p * net.track_forward(pos)).sum(-1)
            neg_scores = p @ cand_emb.T
            # Exclude the query's own positive if it also appears in the pool.
            neg_scores = neg_scores.masked_fill(
                cand.unsqueeze(0) == pos.unsqueeze(1), float('-inf'))
            rank = (neg_scores > pos_score.unsqueeze(1)).sum(dim=1)
            hits += (rank < K).sum().item()
            n += pos.size(0)
    return hits / max(1, n)

# Train for a few epochs and keep the checkpoint with the best validation Recall@10.
epochs = 3
best = 0.0
for ep in range(1, epochs + 1):
    tr = train_one_epoch()
    r10 = recall_at_k(val_loader, K=10, pool=2000)
    print(f'epoch {ep}: loss={tr:.4f}, val@10={r10:.4f}')
    if r10 <= best:
        continue  # no improvement: keep the previously saved checkpoint
    best = r10
    checkpoint = {
        'state_dict': model.state_dict(),
        'n_tracks': n_tracks,
        'embed_dim': embed_dim,
    }
    torch.save(checkpoint, 'two_tower_best.pt')
print('best val@10', best)

In [None]:
# Sanity checks for pl_seqs: playlist count, playlist-size histogram, tid range.
# Uses `psf` (pyspark.sql.functions) rather than `F`: by this point `F` has been
# rebound to torch.nn.functional by the model cell, so `F.size` would fail.
print('playlists (>= 2 tracks):', pl_seqs.count())
size_hist = (
    pl_seqs
    .withColumn('n_tracks', psf.size('tids'))
    .groupBy('n_tracks').count()
    .orderBy('n_tracks')
)
mins_maxs = edges_shift.agg(psf.min('tid').alias('min_tid'), psf.max('tid').alias('max_tid')).first()
print("tid range:", mins_maxs["min_tid"], "to", mins_maxs["max_tid"])
size_hist.show(10)

In [None]:
# Preview 5 playlist_tracks rows ordered by playlist id
spark.table("playlist_tracks").orderBy("pid").limit(5).show()

In [None]:
# Preview the 5 lowest shifted track ids in edges_shift
spark.table("edges_shift").orderBy("tid").limit(5).show()


In [None]:
# Preview the 5 lowest shifted track ids in track_features_shift
spark.table("track_features_shift").orderBy("tid").limit(5).show()

In [None]:
# Preview the first 5 rows of the artist-level item features
artist_features.show(5)

In [None]:
# Preview the first 5 rows of the playlist (user) features
playlist_features.show(5)

In [None]:
# Preview the first 5 rows of the original (unshifted) track features
track_features.show(5)

In [None]:
# Verification: number of distinct playlists (pid), then distinct tracks (tid), in edges
for id_col in ('pid', 'tid'):
    print(edges.select(id_col).distinct().count())

In [None]:
# Preview the first 5 rows of the original (unshifted) edges table
edges.show(5)

In [None]:
# InfoNCE with one positive and multiple explicit negatives per query
def info_nce_loss(query, positive, negatives, temperature=0.07):
    """InfoNCE (softmax contrastive) loss with one positive and N negatives per query.

    Args:
        query: [batch, dim] query embeddings.
        positive: [batch, dim] the matching item for each query.
        negatives: [batch, num_neg, dim] non-matching items for each query.
        temperature: divisor applied to every dot-product similarity.

    Returns:
        Scalar mean cross-entropy, with the positive (logit column 0) as the target.
    """
    pos_sim = (query * positive).sum(dim=-1, keepdim=True)           # [batch, 1]
    neg_sim = torch.bmm(negatives, query.unsqueeze(-1)).squeeze(-1)  # [batch, num_neg]
    logits = torch.cat([pos_sim, neg_sim], dim=1) / temperature
    labels = torch.zeros(query.size(0), dtype=torch.long, device=query.device)
    # Use torch's functional module explicitly: in this notebook the global `F` is
    # pyspark.sql.functions until the model cell rebinds it.
    return torch.nn.functional.cross_entropy(logits, labels)


In [None]:
# Show the column names and types of the tracks (items) table
tracks.printSchema()