Given user data, generate query embeddings, then search for the top-k candidate embeddings. Using a holdout validation set, we check whether the top-k results contain the next track the user actually listened to (recall@k).

In [1]:
# Load the preprocessed validation split from disk (already processed for
# simplicity) and wrap it in tower's windowed dataset for evaluation.
import os

from datasets import load_from_disk
import tower

filepath = '../../datasets/yambda/processed'
val_split_dir = os.path.join(filepath, "val")
val_processed = load_from_disk(val_split_dir)

val_dataset = tower.MusicWindowDataset(
    data=val_processed,
    n=10,
    # Per the original author: the value does not matter here because it is
    # reset to 0 in the forward pass.
    global_t_max=0,
    # Presumably caps how many windows are generated per user — see
    # tower.MusicWindowDataset.
    max_windows_per_user=1000,
)

  from .autonotebook import tqdm as notebook_tqdm


Generating samples from sequences...
Generated 1352980 total samples.


In [2]:
# Rebuild the query tower with the training-time hyperparameters, then restore
# its trained weights from the last checkpoint.
from tower import QueryTower
import os
import torch

# Vocabulary sizes (number of distinct ids per feature). The embedding tables
# below are sized from these, so they must match the values used at training
# time or load_state_dict will fail with a shape mismatch.
len_unique_users = 9238
len_unique_items = 877168
len_unique_albums = 3367691
len_unique_artists = 1293394

# Per-feature embedding widths (must also match the checkpoint).
user_embed_size = 14
item_embed_size = 20
album_embed_size = 22
artist_embed_size = 21

# Widths of the concatenated input feature vectors. Item/artist/album terms
# appear twice in the query input; the trailing +3 / +1 presumably account for
# extra scalar features defined in tower (TODO: document what they are).
query_input_size = user_embed_size + 2*item_embed_size + 2*artist_embed_size + 2*album_embed_size + 3
# Not used in this cell (the candidate tower is not rebuilt here).
candidate_input_size = item_embed_size + artist_embed_size + album_embed_size + 1

# Embedding tables get one extra row (+1), presumably reserving an id such as
# padding/unknown — confirm against the preprocessing code.
item_embed = torch.nn.Embedding(num_embeddings=len_unique_items+1, embedding_dim=item_embed_size)
album_embed = torch.nn.Embedding(num_embeddings=len_unique_albums+1, embedding_dim=album_embed_size)
artist_embed = torch.nn.Embedding(num_embeddings=len_unique_artists+1, embedding_dim=artist_embed_size)

# The last hidden size (128) is the dimensionality of the query embeddings that
# are later searched against the faiss index. log_age_mean / log_age_std are
# presumably normalization statistics for a log-age feature computed on the
# training data — confirm against the training notebook.
query_model = QueryTower(input_size = query_input_size,
                         hidden_size = [1024, 512, 128],
                         user_num_embeddings = len_unique_users+1,
                         user_embed_size = user_embed_size,
                         item_embed = item_embed,
                         artist_embed = artist_embed,
                         album_embed = album_embed,
                         log_age_mean = 15.874020,
                         log_age_std = 1.090574
                        )


# Load the last checkpoint onto the CPU so this cell works regardless of the
# device the checkpoint was saved from (the model is moved to the evaluation
# device in a later cell). Per the original note, this restores the saved
# embedding weights as well as the query hidden layers.
checkpoints = "./checkpoints"
checkpoint = torch.load(os.path.join(checkpoints, "last_checkpoint.pth"), map_location="cpu")
query_model.load_state_dict(checkpoint["query_model_state_dict"])

<All keys matched successfully>

In [3]:
# Encode every validation query with the trained query tower and collect the
# held-out target track id for each sample, keeping the two arrays row-aligned.
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np

# Prefer Apple-silicon MPS (the original setup), then CUDA, then CPU, so the
# notebook also runs on machines without MPS instead of crashing.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
B = 8192  # inference batch size
query_model.to(device)
query_model.eval()

# shuffle is left at its default (False), so batches follow val_dataset order
# and embeddings stay aligned with their targets.
query_dataloader = DataLoader(val_dataset, batch_size=B, num_workers=8)
all_val_query_embeddings = []
all_val_target_ids = []

print("Generating embeddings for the Validation set...")
with torch.no_grad():
    for query, item, _ in tqdm(query_dataloader):
        # query is a dict of feature tensors; move each one to the model's device.
        batch = {k: v.to(device) for k, v in query.items()}

        query_embeddings = query_model(batch)
        # L2-normalize so inner-product search in faiss behaves like cosine
        # similarity (assuming the indexed candidate vectors are normalized too).
        query_embeddings = torch.nn.functional.normalize(query_embeddings, p=2, dim=1)

        # Move to CPU / numpy for faiss; the candidate id is the ground-truth
        # next track used for recall@k.
        all_val_query_embeddings.append(query_embeddings.cpu().numpy())
        all_val_target_ids.append(item["candidate_id"].numpy())

val_query_embeddings_np = np.concatenate(all_val_query_embeddings, axis=0)
val_target_np = np.concatenate(all_val_target_ids, axis=0)

Generating embeddings for the Validation set...


100%|██████████| 168/168 [03:08<00:00,  1.12s/it]


In [4]:
# The recall@k cells compare row i of the query embeddings with target i, so
# both arrays must have the same number of rows.
assert len(val_query_embeddings_np) == len(val_target_np), "query/target row count mismatch"
print(val_query_embeddings_np.shape, val_target_np.shape)

(1352980, 128) (1352980,)


In [5]:
# Configure the OpenMP / MKL runtime, then import faiss.
# These environment variables are read when the runtime initializes, so they
# must be set *before* `import faiss`; setting them afterwards may silently have
# no effect.
# NOTE(review): torch and numpy are already imported by earlier cells, so for a
# fully reliable setup move these assignments into the very first cell.
import os
# Presumably works around duplicate OpenMP runtimes being loaded (e.g. one
# bundled with torch and one with faiss).
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

import faiss

# Enforce single-threaded faiss search explicitly, independent of whether the
# env vars above were read in time.
faiss.omp_set_num_threads(1)

In [45]:
# Recall@k for the baseline index (presumably exact / flat k-NN — confirm how
# baseline.index was built). A query is a hit at k when its held-out target
# track id appears among the top-k retrieved ids.
index = faiss.read_index("baseline.index")

# Evaluate on the first `sample_size` validation queries; set it to
# len(val_query_embeddings_np) to evaluate the full validation set.
# NOTE(review): if validation samples are ordered by user, the first rows may
# cover only a few users; a seeded random subset (shared with the IVFPQ cell)
# would give a less biased estimate.
sample_size = 4000
k_values = [5, 10, 30, 50, 100, 300, 500, 1000]

# Search once at the largest k. faiss returns neighbours sorted by score, so
# for an exact index the top-k for every smaller k is a prefix of this result.
# This replaces one full search per k.
_, I = index.search(val_query_embeddings_np[:sample_size], max(k_values))

# (sample, 1) target column broadcast against the (sample, max_k) retrieved ids.
hits = I == val_target_np[:sample_size][:, None]
for k in k_values:
    recall_k = hits[:, :k].any(axis=1).mean()  # fraction of queries with a hit in the top k
    print(f"recall@{k}: {recall_k}")

recall@5: 0.012
recall@10: 0.01925
recall@30: 0.0405
recall@50: 0.05225
recall@100: 0.07375
recall@300: 0.12575
recall@500: 0.166
recall@1000: 0.229


In [46]:
# Recall@k for the IVFPQ (approximate) index on the same first-N validation
# queries as the baseline cell, so the two results are directly comparable.
# NOTE(review): IVF recall depends on index.nprobe, which is persisted in the
# index file; record or set it explicitly when comparing configurations.
index = faiss.read_index("items.index")
sample_size = 4000
k_values = [5, 10, 30, 50, 100, 300, 500, 1000]

# Search once at the largest k and take smaller-k results as prefixes of the
# score-sorted list. This assumes the candidate pool does not grow with k,
# which holds for a plain IVFPQ (the probed lists depend only on nprobe) but
# not for e.g. a k-dependent refine stage or HNSW. Confirm the index type.
_, I = index.search(val_query_embeddings_np[:sample_size], max(k_values))

# (sample, 1) target column broadcast against the (sample, max_k) retrieved ids.
hits = I == val_target_np[:sample_size][:, None]
for k in k_values:
    recall_k = hits[:, :k].any(axis=1).mean()  # fraction of queries with a hit in the top k
    print(f"recall@{k}: {recall_k}")

recall@5: 0.005
recall@10: 0.00875
recall@30: 0.0185
recall@50: 0.0235
recall@100: 0.03375
recall@300: 0.0585
recall@500: 0.07725
recall@1000: 0.114
