In [1]:
import numpy as np
import pandas as pd

from gensim.models import Word2Vec



### Prepare data

In [2]:
df = pd.read_csv("../data/rec_test_assignment_playlist2track.csv")

In [3]:
df.head()

Unnamed: 0,playlist_id,track_id,track_uri
0,0,0,spotify:track:1r0faljjM2b876iNoaDUh5
1,1,1,spotify:track:1JO1xLtVc8mWhIoE3YaCL0
2,2,2,spotify:track:3SuzCeGoNOWmbWOoDMou0B
3,3,3,spotify:track:4GJAd1nBylAEbw1dZDVzEQ
4,4,4,spotify:track:5cbpoIu3YjoOwbBDGUEp3P


In [4]:
n_track = df['track_id'].unique().shape[0]
n_playlist = df['playlist_id'].unique().shape[0]
print('tracks: {}, playlists: {}'.format(n_track, n_playlist))

tracks: 169548, playlists: 714818


In [5]:
N_TOP = 10

playlists_df = df.groupby('playlist_id')['track_id'].agg(neighbours=list)
playlists_df['n_neighbours'] = playlists_df['neighbours'].apply(len)
playlists_df = playlists_df[playlists_df['n_neighbours'] >= N_TOP].drop('n_neighbours', axis=1)

playlists = playlists_df['neighbours'].to_list()

In [6]:
len(playlists)

9107

In [7]:
df = df.merge(
        playlists_df,
        left_on='playlist_id',
        right_index=True
    )

df['neighbours'] = df.apply(lambda r: [n for n in r['neighbours'] if n != r['track_id']], axis=1)

### Train and metrics

In [8]:
model = Word2Vec(playlists, min_count=1, vector_size=600, epochs=20)

In [9]:
def search_top(id, topn):
    return [item[0] for item in model.wv.most_similar(id, topn=topn)]


df['neighbours_pred'] = df['track_id'].apply(lambda x: search_top(x, N_TOP))

In [10]:
def precision_at_k(true, pred):
    return len(set(pred) & set(true)) / len(pred)


def recall_at_k(true, pred):
    return len(set(pred) & set(true)) / len(true)


triplet_precision = df.apply(lambda r: precision_at_k(r['neighbours'], r['neighbours_pred']), axis=1).mean()
triplet_recall = df.apply(lambda r: recall_at_k(r['neighbours'], r['neighbours_pred']), axis=1).mean()

In [11]:
print("Triplet precision: {:.6f}".format(triplet_precision))
print("Triplet recall: {:.6f}".format(triplet_recall))

Triplet precision: 0.007279
Triplet recall: 0.007137
