In [53]:
import sys, os.path as op
sys.path.insert(0, '/home/yon/jupyter-server/mlg/src/')

In [54]:
source_dir = "spotify_million_playlist_dataset/pickles"
!ls spotify_million_playlist_dataset/pickles

carloss62.pkl	      G_example.pkl	      top-G-1000.pkl
carloss66.pkl	      ghetero_example.pkl     top-G-5000_example.pkl
datasets_example.pkl  ghetero.pkl	      top-ghetero-1000.pkl
datasets.pkl	      G.pkl		      top-idx-1000.pkl
deleteme.pkl	      top-G-1000_example.pkl


In [55]:
import pickle
from model import *
graph = pickle.load(open(op.join(source_dir, "top-ghetero-5000.pkl"), "rb"))
model = pickle.load(open(op.join(source_dir, "carloss62.pkl"), "rb"))
name_to_id = pickle.load(open(op.join(source_dir, "top-idx-5000.pkl"), "rb"))
id_to_name = {T: {v : k for k, v in dictionary.items()} for T, dictionary in name_to_id.items()}

In [68]:
import torch

def count_unique_connections(graph, edge_type, position=1):
    return len(torch.unique(graph[edge_type].edge_index[position, :], dim=0))

def add_playlist(graph, playlist):
    # Set new playlist id
    playlist_id = len(graph["playlist"].x)

    # Expected playlist parameters
    key_defaults = {
        "num_followers": 0,
        "collaborative": False,
        "num_edits": 0,
        "tracks": []
    }

    # Set default values for missing parameters
    playlist = {**key_defaults, **playlist}

    # Count unique elements
    playlist["num_tracks"] = len(playlist["tracks"])
    playlist["num_artists"] = count_unique_connections(graph, ("track", "authors", "artist"))
    playlist["num_albums"] = count_unique_connections(graph, ("track", "includes", "album"))

    # Sum durations
    playlist["duration_ms"] = sum([graph["track"].x[track_id][0] for track_id in playlist["tracks"]])

    # Extract playlist embedding
    playlist_embedding_array = [playlist["num_followers"], playlist["collaborative"], playlist["num_albums"], playlist["num_tracks"], playlist["num_edits"], playlist["duration_ms"], playlist["num_artists"]]
    playlist_embedding = torch.FloatTensor(playlist_embedding_array).reshape(1, -1)

    # Add playlist to graph
    graph["playlist"].x = torch.cat((graph["playlist"].x, playlist_embedding), dim=0)

    # Add track to playlist connections
    for track_id in playlist["tracks"]:
        graph["track", "contains", "playlist"].edge_index = torch.cat((graph["track", "contains", "playlist"].edge_index, torch.LongTensor([[track_id], [playlist_id]])), dim=1)

        track_artist_index = graph["track", "authors", "artist"].edge_index[0, :].eq(track_id).nonzero()
        if track_artist_index.size(0) == 0:
            raise Exception("Track {} has no artist".format(track_id))
        
        artist_id = graph["track", "authors", "artist"].edge_index[1, track_artist_index.item()]

        track_album_index = graph["track", "includes", "album"].edge_index[0, :].eq(track_id).nonzero()
        if track_album_index.size(0) == 0:
            raise Exception("Track {} has no album".format(track_id))
        album_id = graph["track", "includes", "album"].edge_index[1, track_album_index.item()]

        graph["track", "authors", "artist"].edge_index = torch.cat((graph["track", "authors", "artist"].edge_index, torch.LongTensor([[track_id], [artist_id]])), dim=1)
        graph["track", "includes", "album"].edge_index = torch.cat((graph["track", "includes", "album"].edge_index, torch.LongTensor([[track_id], [album_id]])), dim=1)

    # Construct index of all track connections to new playlist for prediction
    track_ids = torch.LongTensor([i for i in range(len(graph["track"].x))])
    playlist_ids = torch.LongTensor([playlist_id] * len(graph["track"].x))
    new_playlist_tracks_edge_index = torch.cat((track_ids.reshape(1, -1), playlist_ids.reshape(1, -1)), dim=0)
    graph["track", "contains", "playlist"].edge_label_index = new_playlist_tracks_edge_index

In [69]:
# Set playlist tracks by name
tracks_names = ["spotify:track:7N3PAbqfTjSEU1edb2tY8j"]
track_ids = [name_to_id["track"][name] for name in tracks_names]

# Add playlist and set track connections for prediction
graph.to("cpu")
add_playlist(graph, {
    "collaborative": False,
    "num_edits": 0,
    "num_followers": 0,
    "tracks": track_ids
})

Exception: Track 35806 has no artist

In [None]:
# predict track connections
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
graph = graph.to(device)
with torch.no_grad():
    pred = model(graph)
    
    # get top 10 predictions
    most_likely = torch.topk(pred, 10, dim=0)

In [None]:
# Most likely track id prediction
most_likely

torch.return_types.topk(
values=tensor([-0.2562, -0.2562, -0.2562, -0.2562, -0.2562, -0.2562, -0.2562, -0.2563,
        -0.2563, -0.2564], device='cuda:0'),
indices=tensor([2603,  978, 1184, 1084, 1259, 1007, 4003, 8128, 8961, 9201],
       device='cuda:0'))

In [49]:
id_to_name

{'playlist': {0: 'spotify:playlist:489745',
  1: 'spotify:playlist:363112',
  2: 'spotify:playlist:551199',
  3: 'spotify:playlist:632384',
  4: 'spotify:playlist:881739',
  5: 'spotify:playlist:881244',
  6: 'spotify:playlist:848402',
  7: 'spotify:playlist:721448',
  8: 'spotify:playlist:652590',
  9: 'spotify:playlist:207874',
  10: 'spotify:playlist:700585',
  11: 'spotify:playlist:540716',
  12: 'spotify:playlist:647824',
  13: 'spotify:playlist:866032',
  14: 'spotify:playlist:900034',
  15: 'spotify:playlist:859628',
  16: 'spotify:playlist:168195',
  17: 'spotify:playlist:107343',
  18: 'spotify:playlist:879085',
  19: 'spotify:playlist:534865',
  20: 'spotify:playlist:28143',
  21: 'spotify:playlist:173508',
  22: 'spotify:playlist:261423',
  23: 'spotify:playlist:261519',
  24: 'spotify:playlist:1750',
  25: 'spotify:playlist:838728',
  26: 'spotify:playlist:210686',
  27: 'spotify:playlist:332559',
  28: 'spotify:playlist:852038',
  29: 'spotify:playlist:716271',
  30: 'spot

In [None]:
# Predicted track names
[id_to_name["track"][i.item()] for i in most_likely.indices]

['spotify:track:1pfSKTDigM2xnWG2Y0AquA',
 'spotify:track:1kWqRB6PHLM9QTfaXuHpIV',
 'spotify:track:7hjJKN8FicWf2uq39ASi2e',
 'spotify:track:26sqM0uf0Ka7FDHkdRfbqh',
 'spotify:track:2YJgM8lj4u5YEZRSHSfL60',
 'spotify:track:4ZFNuJ9eD41EuqITQBYPWG',
 'spotify:track:1MVHp4OB3ZFtJp6QiLUUOB',
 'spotify:track:6z9cUmD9oGGmHquxadVHkz',
 'spotify:track:1qJZ4aYAUeKkjDpjhb9TuN',
 'spotify:track:0wRCXAjEC4yr8ghGTqX6OB']