In [45]:
import pandas as pd
import numpy as np
from lightfm import LightFM
from scipy import sparse
from time import time
from lightfm.evaluation import precision_at_k
from tqdm import tqdm

In [2]:
interactions_train = pd.read_parquet('/scratch/work/courses/DSGA1004-2021/listenbrainz/interactions_train_small.parquet')
interactions_train = interactions_train.drop(columns='timestamp')
interactions_test = pd.read_parquet('/scratch/work/courses/DSGA1004-2021/listenbrainz/interactions_test.parquet')
interactions_test = interactions_test.drop(columns='timestamp')

In [3]:
# Filter users with at least 50 interactions
interactions_count = interactions_train.groupby('user_id').size().reset_index(name='num_interactions')
selected_user_ids = interactions_count[interactions_count.num_interactions >= 50]['user_id'].values
train_filtered = interactions_train[interactions_train.user_id.isin(selected_user_ids)]

In [4]:
# Map recording_msid to track_id
unique_msid = pd.concat([train_filtered['recording_msid'], interactions_test['recording_msid']]).unique()
df_trackid = pd.DataFrame({'recording_msid': unique_msid, 'track_id': np.arange(len(unique_msid))})

In [5]:
# Create R matrix
train_R_msid = train_filtered.groupby(['user_id', 'recording_msid']).size().reset_index(name='count')
test_R_msid = interactions_test.groupby(['user_id', 'recording_msid']).size().reset_index(name='count')

In [6]:
# Add track_id
train_R = pd.merge(train_R_msid, df_trackid, on='recording_msid')[['user_id', 'track_id', 'count']]
test_R = pd.merge(test_R_msid, df_trackid, on='recording_msid')[['user_id', 'track_id', 'count']]

In [7]:
# Create true labels for test and save test
test_true = test_R.groupby('user_id')['track_id'].apply(set).reset_index(name='true')
test_true_sorted = test_true.sort_values('user_id')

In [8]:
# Sort train
train_R_sorted = train_R.sort_values(['user_id', 'track_id'])

In [9]:
train_R_sorted.to_parquet('train_inter_small.parquet')
test_true_sorted.to_parquet('test_true.parquet')

In [10]:
train_R_sorted.to_csv('train_inter_small.csv', index = False)
test_true_sorted.to_csv('test_true.csv', index = False)

In [55]:
# Read data
train = pd.read_parquet('train_inter_small.parquet')
train = train.sample(frac=0.01, random_state=42)
test_true = pd.read_parquet('test_true.parquet')

In [56]:
# Create user and item id mappings
user_ids = train['user_id'].unique()
item_ids = train['track_id'].unique()

In [57]:
user_to_index = {original: idx for idx, original in enumerate(user_ids)}
index_to_user = {idx: original for original, idx in user_to_index.items()}
item_to_index = {original: idx for idx, original in enumerate(item_ids)}
index_to_item = {idx: original for original, idx in item_to_index.items()}

In [58]:
# Create sparse matrix for training data
rows = train['user_id'].map(user_to_index)
cols = train['track_id'].map(item_to_index)
interactions = sparse.coo_matrix((train['count'], (rows, cols)), shape=(len(user_to_index), len(item_to_index)))

In [59]:
# LightFM model
model = LightFM(loss='warp')
start_time = time()
model.fit(interactions, epochs=30, num_threads = 16)

<lightfm.lightfm.LightFM at 0x148869ed3880>

In [61]:
# Prepare test set
test_true = test_true[test_true['user_id'].isin(user_to_index.keys())]
test_true = test_true.reset_index(drop=True)  # reset the index after filtering
test_user_ids = test_true['user_id'].map(user_to_index).values.astype(int)
test_track_ids = [set(test_true['true'][i]) for i in range(len(test_true))]

In [62]:
# Compute MAP@K for test set
mapk_test = []
for user_id, true_tracks in tqdm(zip(test_user_ids, test_track_ids)):
    # repeat the user id for the number of items
    user_ids_array = np.full(interactions.shape[1], user_id, dtype=np.int32)
    item_ids_array = np.arange(interactions.shape[1], dtype=np.int32)
    scores = model.predict(user_ids_array, item_ids_array)
    top_items = np.argsort(-scores)[:100]  # get top 100 items
    pred_tracks = set(top_items)
    mapk_test.append(len(true_tracks & pred_tracks) / len(pred_tracks))

5016it [02:14, 37.36it/s]


In [63]:
mapk_test = np.mean(mapk_test)
end_time = time()
lightfm_time = end_time - start_time
print(f'The total train time for lightFM is: {lightfm_time}')
print(f"MAP@K for test set: {mapk_test}")

The total train time for lightFM is: 155.69324612617493
MAP@K for test set: 0.0012141148325358854
