In [1]:
import sys, os.path as op
sys.path.insert(0, '/home/yon/jupyter-server/mlg/src/')

In [2]:
source_dir = "spotify_million_playlist_dataset/pickles"
!ls spotify_million_playlist_dataset/pickles

carloss62.pkl	      ghetero_example.pkl     top-ghetero-1000.pkl
carloss66.pkl	      ghetero.pkl	      top-ghetero-5000-fixed-maybe.pkl
carloss72.pkl	      G.pkl		      top-ghetero-5000.pkl
datasets_example.pkl  index.pkl		      top-idx-1000.pkl
datasets.pkl	      top-G-1000_example.pkl  top-idx-5000.pkl
deleteme.pkl	      top-G-1000.pkl
G_example.pkl	      top-G-5000.pkl


In [3]:
import pickle
from model import *
graph = pickle.load(open(op.join(source_dir, "top-ghetero-5000.pkl"), "rb"))
model = pickle.load(open(op.join(source_dir, "carloss72.pkl"), "rb"))
name_to_id = pickle.load(open(op.join(source_dir, "top-idx-5000.pkl"), "rb"))
id_to_name = {T: {v : k for k, v in dictionary.items()} for T, dictionary in name_to_id.items()}
index = pickle.load(open(op.join(source_dir, "index.pkl"), "rb"))

In [4]:
import torch

def count_unique_connections(graph, edge_type, position=1):
    return len(torch.unique(graph[edge_type].edge_index[position, :], dim=0))

def add_playlist(graph, playlist, exceptions=True):
    # Set new playlist id
    playlist_id = len(graph["playlist"].x)

    # Expected playlist parameters
    key_defaults = {
        "num_followers": 0,
        "collaborative": False,
        "num_edits": 0,
        "tracks": []
    }

    # Set default values for missing parameters
    playlist = {**key_defaults, **playlist}

    # Count unique elements
    playlist["num_tracks"] = len(playlist["tracks"])
    playlist["num_artists"] = count_unique_connections(graph, ("track", "authors", "artist"))
    playlist["num_albums"] = count_unique_connections(graph, ("track", "includes", "album"))

    # Sum durations
    playlist["duration_ms"] = sum([graph["track"].x[track_id][0] for track_id in playlist["tracks"]])

    # Extract playlist embedding
    playlist_embedding_array = [playlist["num_followers"], playlist["collaborative"], playlist["num_albums"], playlist["num_tracks"], playlist["num_edits"], playlist["duration_ms"], playlist["num_artists"]]
    playlist_embedding = torch.FloatTensor(playlist_embedding_array).reshape(1, -1)

    # Add playlist to graph
    graph["playlist"].x = torch.cat((graph["playlist"].x, playlist_embedding), dim=0)

    # Add track to playlist connections
    for track_id in playlist["tracks"]:
        graph["track", "contains", "playlist"].edge_index = torch.cat((graph["track", "contains", "playlist"].edge_index, torch.LongTensor([[track_id], [playlist_id]])), dim=1)

        track_artist_index = graph["track", "authors", "artist"].edge_index[0, :].eq(track_id).nonzero()
        if track_artist_index.size(0) == 0:
            if exceptions:
                raise Exception("Track {} has no artist".format(track_id))
        else:
            artist_id = graph["track", "authors", "artist"].edge_index[1, track_artist_index]
            artist_id = artist_id if artist_id.dim() == 0 else artist_id[0]
            graph["track", "authors", "artist"].edge_index = torch.cat((graph["track", "authors", "artist"].edge_index, torch.LongTensor([[track_id], [artist_id]])), dim=1)
        

        track_album_index = graph["track", "includes", "album"].edge_index[0, :].eq(track_id).nonzero()
        if track_album_index.size(0) == 0:
            if exceptions:
                raise Exception("Track {} has no album".format(track_id))
        else:
            album_id = graph["track", "includes", "album"].edge_index[1, track_album_index]
            album_id = album_id if album_id.dim() == 0 else album_id[0]
            graph["track", "includes", "album"].edge_index = torch.cat((graph["track", "includes", "album"].edge_index, torch.LongTensor([[track_id], [album_id]])), dim=1)

    # Construct index of all track connections to new playlist for prediction
    track_ids = torch.LongTensor([i for i in range(len(graph["track"].x))])
    playlist_ids = torch.LongTensor([playlist_id] * len(graph["track"].x))
    new_playlist_tracks_edge_index = torch.cat((track_ids.reshape(1, -1), playlist_ids.reshape(1, -1)), dim=0)
    graph["track", "contains", "playlist"].edge_label_index = new_playlist_tracks_edge_index

In [5]:
# Set playlist tracks by name
tracks_names = ['spotify:track:5mMxriCaSYyZopQDPYkyqT',
  'spotify:track:0I9KnKdIVz8envRSA6yGZL',
  'spotify:track:0FWhGmPVxLI6jOVF0wjALa',
  'spotify:track:2yNWwardt8VzlpNBWrGYD6',
  'spotify:track:159wbQ5j11xBszXYFym04u',
  'spotify:track:1Jm1I3APcmVz3MqPr5vfTx',
  'spotify:track:3R97rNX7JnmshCWBwOSFet',
  'spotify:track:1AwJGxWNl5n8O2CSlvPKYL',
  'spotify:track:5GorFaKkP2mLREQvhSblIg',
  'spotify:track:29z6nESQsLBgtJkUXOJvGN',
  'spotify:track:3koCCeSaVUyrRo3N2gHrd8',
  'spotify:track:5lA3pwMkBdd24StM90QrNR',
  'spotify:track:5D9Nw6HyFH0k40X8RxHfD6',
  'spotify:track:1JClFT74TYSXlzpagbmj0S',
  'spotify:track:0DBIL8arX0Zo6eAuxNIpik',
  'spotify:track:2RqZFOLOnzVmHUX7ZMcaES',
  'spotify:track:4tVEJTR4VgBvvb2R6phA2v',
  'spotify:track:0sKlV58cODrjxGFOyf9IXY',
  'spotify:track:6b8Be6ljOzmkOmFslEb23P',
  'spotify:track:3M5eeHXgG4VplKNcsBC8Dj',
  'spotify:track:6X4JeTWCuKEzKOEHXDtyBo',
  'spotify:track:5DuTNKFEjJIySAyJH1yNDU',
  'spotify:track:3G6hxSp260RzGw4sOiDOQ3',
  'spotify:track:1aWV3uY3SIEZVbmv45oFWS',
  'spotify:track:6m59VvDUi0UQsB2eZ9wVbH',
  'spotify:track:1v0ufp7FLTFcykUGOmFZKa',
  'spotify:track:5W4vPDfwFNQqt7frRjL41t',
  'spotify:track:161DnLWsx1i3u1JT05lzqU',
  'spotify:track:0KKkJNfGyhkQ5aFogxQAPU',
  'spotify:track:5BLBpcQLF54Lxg3ufn1GCT',
  'spotify:track:7J41dYQolQJEtj3UmKLu5r',
  'spotify:track:3l3xTXsUXeWlkPqzMs7mPD',
  'spotify:track:2dCmGcEOQrMQhMMS8Vj7Ca',
  'spotify:track:5lXcSvHRVjQJ3LB2rLKQog',
  'spotify:track:5wiu4PUC6CLNNbC48vqGOb',
  'spotify:track:1OsCKwNZxph96EkNusILRy',
  'spotify:track:5ChkMS8OtdzJeqyybCc9R5',
  'spotify:track:3eRsSIhorhBnLrsy2uhM8r',
  'spotify:track:0nyrltZrQGAJMBZc1bYvuQ']
track_ids = [name_to_id["track"][name] for name in tracks_names if name in name_to_id["track"]]
print("Could not find {} tracks out of {}".format(len(tracks_names) - len(track_ids), len(tracks_names)))

# Add playlist and set track connections for prediction
graph.to("cpu")
add_playlist(graph, {
    "collaborative": False,
    "num_edits": 0,
    "num_followers": 0,
    "tracks": track_ids
}, exceptions=False)

Could not find 0 tracks out of 39


In [6]:
# predict track connections
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
graph = graph.to(device)
with torch.no_grad():
    pred = model(graph)
    
    # get top 10 predictions
    most_likely = torch.topk(pred, 10, dim=0)

    # get worst 10
    most_unlikely = torch.topk(pred, 10, dim=0, largest=False)

In [7]:
# Best
[(index["track"][id_to_name["track"][track_id.item()].split(":")[-1]], id_to_name["track"][track_id.item()]) for track_id in most_likely.indices]

[("It's A Vibe", 'spotify:track:6H0AwSQ20mo62jGlPGB8S6'),
 ('King Kunta', 'spotify:track:0N3W5peJUQtI4eyR6GJT5O'),
 ('Iris', 'spotify:track:6Qyc6fS4DsZjB2mRW9DsQs'),
 ('Work from Home', 'spotify:track:4tCtwWceOPWzenK2HAIJSb'),
 ('Fight Song', 'spotify:track:37f4ITSlgPX81ad2EvmVQr'),
 ('Hey Mama (feat. Nicki Minaj, Bebe Rexha & Afrojack)',
  'spotify:track:285HeuLxsngjFn4GGegGNm'),
 ('Hymn For The Weekend - Seeb Remix', 'spotify:track:1OAiWI2oPmglaOiv9fdioU'),
 ('Cheap Thrills', 'spotify:track:27SdWb2rFzO6GWiYDBTD9j'),
 ('Rude', 'spotify:track:6RtPijgfPKROxEzTHNRiDp'),
 ('All I Want', 'spotify:track:5JuA3wlm0kn7IHfbeHV0i6')]

In [8]:
# Worst
[(index["track"][id_to_name["track"][track_id.item()].split(":")[-1]], id_to_name["track"][track_id.item()]) for track_id in most_unlikely.indices]

[('Motets for 4 Voices, Book 2: Sicut cervus',
  'spotify:track:5xkrA4nf1k3gm0LS0ILa1x'),
 ('Hal', 'spotify:track:69qVjSZc2yDnj5TeFxSLpW'),
 ('Give Jah Glory', 'spotify:track:5eJneUnQneHSGkxCKBINZ0'),
 ('Tema a Lo Clan', 'spotify:track:5VlnqsgDtvGUmf5spvayAq'),
 ('Eniaro', 'spotify:track:4PzXygnUYEOPPnc5J3Z5Cy'),
 ('The Line (Hi Fi Bros) - Original', 'spotify:track:5m6mXRfOcdGEfQGwsGIqFj'),
 ('Aquarius', 'spotify:track:5blhQziDkdpy2pQYxoT97Y'),
 ('Invisible Prisons', 'spotify:track:52QLeW0S1iz9fhUIzfTCaN'),
 ('Spin', 'spotify:track:6Z1tVFiGHuN6JYWCVmHow6'),
 ('Twin Rays - Worst Friend Remix', 'spotify:track:72lqDdnhND0wLoqFtAIElO')]