In [26]:
import sys, os.path as op
sys.path.insert(0, '/home/yon/jupyter-server/mlg/src/')

In [27]:
source_dir = "spotify_million_playlist_dataset/pickles"
!ls spotify_million_playlist_dataset/pickles

carloss62.pkl	      G_example.pkl	      top-G-5000.pkl
carloss66.pkl	      ghetero_example.pkl     top-ghetero-1000.pkl
carloss72.pkl	      ghetero.pkl	      top-ghetero-5000-fixed-maybe.pkl
datasets_example.pkl  G.pkl		      top-ghetero-5000.pkl
datasets.pkl	      top-G-1000_example.pkl  top-idx-1000.pkl
deleteme.pkl	      top-G-1000.pkl	      top-idx-5000.pkl


In [28]:
import pickle
from model import *
graph = pickle.load(open(op.join(source_dir, "top-ghetero-5000-fixed-maybe.pkl"), "rb"))
model = pickle.load(open(op.join(source_dir, "carloss72.pkl"), "rb"))
name_to_id = pickle.load(open(op.join(source_dir, "top-idx-5000.pkl"), "rb"))
id_to_name = {T: {v : k for k, v in dictionary.items()} for T, dictionary in name_to_id.items()}

In [29]:
import torch

def count_unique_connections(graph, edge_type, position=1):
    return len(torch.unique(graph[edge_type].edge_index[position, :], dim=0))

def add_playlist(graph, playlist):
    # Set new playlist id
    playlist_id = len(graph["playlist"].x)

    # Expected playlist parameters
    key_defaults = {
        "num_followers": 0,
        "collaborative": False,
        "num_edits": 0,
        "tracks": []
    }

    # Set default values for missing parameters
    playlist = {**key_defaults, **playlist}

    # Count unique elements
    playlist["num_tracks"] = len(playlist["tracks"])
    playlist["num_artists"] = count_unique_connections(graph, ("track", "authors", "artist"))
    playlist["num_albums"] = count_unique_connections(graph, ("track", "includes", "album"))

    # Sum durations
    playlist["duration_ms"] = sum([graph["track"].x[track_id][0] for track_id in playlist["tracks"]])

    # Extract playlist embedding
    playlist_embedding_array = [playlist["num_followers"], playlist["collaborative"], playlist["num_albums"], playlist["num_tracks"], playlist["num_edits"], playlist["duration_ms"], playlist["num_artists"]]
    playlist_embedding = torch.FloatTensor(playlist_embedding_array).reshape(1, -1)

    # Add playlist to graph
    graph["playlist"].x = torch.cat((graph["playlist"].x, playlist_embedding), dim=0)

    # Add track to playlist connections
    for track_id in playlist["tracks"]:
        graph["track", "contains", "playlist"].edge_index = torch.cat((graph["track", "contains", "playlist"].edge_index, torch.LongTensor([[track_id], [playlist_id]])), dim=1)

        track_artist_index = graph["track", "authors", "artist"].edge_index[0, :].eq(track_id).nonzero()
        if track_artist_index.size(0) == 0:
            raise Exception("Track {} has no artist".format(track_id))
        artist_id = graph["track", "authors", "artist"].edge_index[1, track_artist_index]
        artist_id = artist_id if artist_id.dim() == 0 else artist_id[0]
        

        track_album_index = graph["track", "includes", "album"].edge_index[0, :].eq(track_id).nonzero()
        if track_album_index.size(0) == 0:
            raise Exception("Track {} has no album".format(track_id))
        album_id = graph["track", "includes", "album"].edge_index[1, track_album_index]
        album_id = album_id if album_id.dim() == 0 else album_id[0]

        graph["track", "authors", "artist"].edge_index = torch.cat((graph["track", "authors", "artist"].edge_index, torch.LongTensor([[track_id], [artist_id]])), dim=1)
        graph["track", "includes", "album"].edge_index = torch.cat((graph["track", "includes", "album"].edge_index, torch.LongTensor([[track_id], [album_id]])), dim=1)

    # Construct index of all track connections to new playlist for prediction
    track_ids = torch.LongTensor([i for i in range(len(graph["track"].x))])
    playlist_ids = torch.LongTensor([playlist_id] * len(graph["track"].x))
    new_playlist_tracks_edge_index = torch.cat((track_ids.reshape(1, -1), playlist_ids.reshape(1, -1)), dim=0)
    graph["track", "contains", "playlist"].edge_label_index = new_playlist_tracks_edge_index

In [30]:
# Set playlist tracks by name
tracks_names = [
    "spotify:track:7N3PAbqfTjSEU1edb2tY8j", "spotify:track:7o2CTH4ctstm8TNelqjb51", "spotify:track:7o2CTH4ctstm8TNelqjb51", 
    "spotify:track:70Z9t1qhytWtG4cCmmi7mU",             
                ]
track_ids = [name_to_id["track"][name] for name in tracks_names]

# Add playlist and set track connections for prediction
graph.to("cpu")
add_playlist(graph, {
    "collaborative": False,
    "num_edits": 0,
    "num_followers": 0,
    "tracks": track_ids
})

In [31]:
# predict track connections
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
graph = graph.to(device)
with torch.no_grad():
    pred = model(graph)
    
    # get top 10 predictions
    most_likely = torch.topk(pred, 10, dim=0)

In [32]:
# Most likely track id prediction
most_likely

torch.return_types.topk(
values=tensor([0.7601, 0.7575, 0.7522, 0.7521, 0.7511, 0.7510, 0.7509, 0.7508, 0.7508,
        0.7507], device='cuda:0'),
indices=tensor([ 96845, 100513,  95372,  56272,  66383,  37171,  91803, 118145,  29110,
        101717], device='cuda:0'))

In [33]:
# Predicted track names
[id_to_name["track"][i.item()] for i in most_likely.indices]

['spotify:track:5tl8usOmZVss5gtf1bDQy6',
 'spotify:track:1ZPCfXdkN4UQkbQRnh9yJW',
 'spotify:track:08hNyQHnhEpEbm3DZNbnH0',
 'spotify:track:43wxfOT7MnPKC306eD35HC',
 'spotify:track:5mPSyjLatqB00IkPqRlbTE',
 'spotify:track:6mghCOaaSvrke0z1EUVUIf',
 'spotify:track:6BbINUfGabVyiNFJpQXn3x',
 'spotify:track:6RcQOut9fWL6FSqeIr5M1r',
 'spotify:track:1Ser4X0TKttOvo8bgdytTP',
 'spotify:track:2q4rjDy9WhaN3o9MvDbO21']